This is exactly what `tf.gather_nd` is for!
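```python
import tensorflow as tf


def extract_axis_1(data, ind):
    """
    Get specified elements along axis 1 (the second axis) of a tensor.
    :param data: TensorFlow tensor that will be subsetted.
    :param ind: Indices to take (one for each element along axis 0 of data).
    :return: Subsetted tensor.
    """
    # Row indices 0..batch_size-1, one per entry along axis 0.
    batch_range = tf.range(tf.shape(data)[0])
    # Pair each row with its index: a [batch_size, 2] tensor of (b, ind[b]).
    # ind is cast so both stacked tensors share a dtype, which tf.stack requires.
    indices = tf.stack([batch_range, tf.cast(ind, batch_range.dtype)], axis=1)
    # Select data[b, ind[b], ...] for every b.
    res = tf.gather_nd(data, indices)
    return res
```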
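To make the index pairs concrete, here is a pure-Python sketch of the same selection on nested lists. `gather_nd_pairs` and `extract_axis_1_reference` are hypothetical helpers written only for illustration; they mimic `tf.gather_nd` for the special case where every index row has exactly two entries `(i, j)`, and are not part of TensorFlow.

```python
def gather_nd_pairs(data, indices):
    """Hypothetical stand-in for tf.gather_nd when each index row is (i, j).

    Each row [i, j] of `indices` selects data[i][j], so the result has one
    entry per index row.
    """
    return [data[i][j] for i, j in indices]


def extract_axis_1_reference(data, ind):
    """Hypothetical list-based mirror of extract_axis_1."""
    batch_range = list(range(len(data)))                   # like tf.range(tf.shape(data)[0])
    indices = [[b, k] for b, k in zip(batch_range, ind)]   # like tf.stack(..., axis=1)
    return gather_nd_pairs(data, indices)


# Toy batch: 2 entries, 3 positions along axis 1, 2 values per position.
data = [
    [[1, 1], [2, 2], [3, 3]],
    [[4, 4], [5, 5], [6, 6]],
]
lengths = [2, 3]
last = extract_axis_1_reference(data, [n - 1 for n in lengths])
print(last)  # [[2, 2], [6, 6]]
```

Here `indices` is `[[0, 1], [1, 2]]`, so row 0 contributes `data[0][1]` and row 1 contributes `data[1][2]`.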
In your case:

```python
output = extract_axis_1(output, lengths - 1)
```
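Now `output` is a tensor of shape `[batch_size, num_cells]`, holding, for each batch entry, the element at position `lengths - 1` along axis 1.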
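If you want to sanity-check the result outside a TensorFlow graph, NumPy advanced indexing performs the same per-row selection. The sketch below assumes a 3-D array shaped `[batch_size, max_time, num_cells]`; `extract_axis_1_np` is a hypothetical helper, and the sizes and `lengths` values are made up for illustration.

```python
import numpy as np


def extract_axis_1_np(data, ind):
    """Hypothetical NumPy analogue of extract_axis_1.

    Indexing with two integer arrays pairs row b with position ind[b],
    the same (b, ind[b]) pairs that tf.gather_nd receives.
    """
    batch_range = np.arange(data.shape[0])
    return data[batch_range, ind]


batch_size, max_time, num_cells = 4, 5, 3
output = np.arange(batch_size * max_time * num_cells).reshape(
    batch_size, max_time, num_cells)
lengths = np.array([5, 2, 3, 1])

last = extract_axis_1_np(output, lengths - 1)
print(last.shape)  # (4, 3)
```

The two index arrays broadcast to shape `(batch_size,)` and the untouched last axis is appended, which is why the result is `[batch_size, num_cells]`.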



