Gather slices from params along the dimension given by axis, according to indices.
tf.gather(params, indices, validate_indices=None, axis=None, batch_dims=0, name=None)
Example:
import tensorflow as tf

# data: [classes, students, subjects]
data = tf.ones([4, 35, 8])
print(data.shape)  # TensorShape([4, 35, 8])

# sample several classes (indices 2 and 3 along axis 0)
data = tf.gather(data, axis=0, indices=[2, 3])
print(data.shape)  # TensorShape([2, 35, 8])
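To make the indexing rule concrete, here is a minimal pure-Python sketch of what gathering along one axis does to a nested list. It is not TensorFlow's implementation: the helpers gather_nested and shape_of are hypothetical names introduced only for this illustration, and the sketch ignores batch_dims and any index validation.

```python
def gather_nested(params, indices, axis=0):
    """Pick the entries listed in `indices` along `axis` of a nested list."""
    if axis == 0:
        # At the target axis: keep only the requested positions, in order.
        return [params[i] for i in indices]
    # Otherwise descend one level and apply the same rule to every sub-list.
    return [gather_nested(sub, indices, axis - 1) for sub in params]


def shape_of(nested):
    """Return the shape of a rectangular nested list as a list of ints."""
    shape = []
    while isinstance(nested, list):
        shape.append(len(nested))
        if not nested:
            break
        nested = nested[0]
    return shape


# data: [classes, students, subjects]; every value stores its class index
data = [[[c for _ in range(8)] for _ in range(35)] for c in range(4)]

# sample classes 2 and 3 (axis 0)
picked_classes = gather_nested(data, [2, 3], axis=0)
print(shape_of(picked_classes))  # [2, 35, 8]

# sample subjects 0, 5 and 7 for every class and student (axis 2)
picked_subjects = gather_nested(data, [0, 5, 7], axis=2)
print(shape_of(picked_subjects))  # [4, 35, 3]
```

Only the gathered axis changes size: it becomes the length of indices, while every other dimension is carried over unchanged.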
tf.gather_nd
Gather slices from params into a Tensor with shape specified by indices.
tf.gather_nd(params, indices, batch_dims=0, name=None)
Example:
import tensorflow as tf

# data: [classes, students, subjects]
data = tf.ones([4, 35, 8])
print(data.shape)  # TensorShape([4, 35, 8])

# sample several (class, student) pairs
# for instance: [class1_student1, class2_student2, class3_student3, class4_student4]
data = tf.gather_nd(data, [[0, 0], [1, 1], [2, 2], [3, 3]])
print(data.shape)  # TensorShape([4, 8])
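The difference from tf.gather is that each row of indices is a coordinate into the leading dimensions of params, not a position along a single axis. The pure-Python sketch below illustrates that rule on nested lists. It is an illustration of the semantics only: gather_nd_nested is a hypothetical helper, it assumes indices is a flat list of coordinate rows, and it does not cover batch_dims.

```python
def gather_nd_nested(params, indices):
    """For each coordinate row in `indices`, walk into `params` and collect the result."""
    result = []
    for coordinate in indices:
        item = params
        for i in coordinate:
            # Each component of the coordinate selects along the next axis.
            item = item[i]
        result.append(item)
    return result


# data: [classes, students, subjects]; every value encodes 100 * class + student
data = [[[100 * c + s for _ in range(8)] for s in range(35)] for c in range(4)]

# one (class, student) pair per row: the subject axis is kept whole
pairs = gather_nd_nested(data, [[0, 0], [1, 1], [2, 2], [3, 3]])
print(len(pairs), len(pairs[0]))  # 4 8
print(pairs[2][0])                # 202, i.e. class 2, student 2

# full (class, student, subject) coordinates select single values
scalars = gather_nd_nested(data, [[1, 2, 3], [3, 34, 7]])
print(scalars)                    # [102, 334]
```

With coordinates of length 2 on a rank-3 input, the result has one row per coordinate plus the remaining subject axis, which matches the [4, 8] shape in the example above.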



