import numpy as np
import torch

# Demo of a batched row lookup with ``torch.gather``.
#
# For every batch ``b`` the script picks whole rows of ``a[b]`` according to
# an integer index tensor, producing
#     c[b, i, k, :] == a[b, index[b, i, k], :]

# Source values 1..8 arranged as (batch=2, rows=2, features=2).
a = torch.tensor(np.arange(1, 9).reshape(2, 2, 2))

# Row selector of shape (2, 3, 4): every entry picks row 1, except the two
# positions below, which pick row 0.  ``torch.gather`` needs an int64 index,
# so the tensor is created as ``torch.long`` directly.
index = torch.ones((2, 3, 4), dtype=torch.long)
index[0, 0, 1] = 0
index[1, 2, 0] = 0

# ``gather`` along dim 1 requires input and index to have the same number of
# dimensions, so both are broadcast (as views, without copying) to 4-D:
#   a     : (2, 2, 2)    -> (2, 2, 1, 2) -> (2, 2, 4, 2)
#   index : (2, 3, 4)    -> (2, 3, 4, 1) -> (2, 3, 4, 2)
feature_count = a.shape[-1]
a = a.unsqueeze(-2).expand(-1, -1, index.shape[2], -1)
index = index.unsqueeze(-1).expand(-1, -1, -1, feature_count)

# Result has the shape of ``index``: (2, 3, 4, 2).
c = torch.gather(a, 1, index)



