首先我要吐槽torch.gather()函数的官方文档 请问它在说个啥 后来根据csdn以及自己的学习 总结出torch.gather()的用法
首先 给出torch.gather()中的几个参数
torch.gather(input, dim, index, out None, sparse_grad False) → Tensor
常用的就是input,dim,index三个参数
input: 你要输入的torch.tensor() dim: 要处理的维度 一个[ ]表示一个维度 比如[ [ 2,3 ] ]中的2和3就是在第二维,dim可以取0 1 2 index: 必须为torch.LongTensor()的类型 且维度大小必须和input相同 index中每一个值表示input在dim维中的下标 下标从0开始说了这么多估计也没说明白 正常正常 先上几个例子自己理解理解
1、dim 1时的情况input torch.arange(15).view(3,5) print( input:n ,input) index1 torch.tensor([ [1, 0], [0, 0], [1, 2]]) print( index:n ,index) print( dim 1时:n ,torch.gather(input,dim 1,index index1))
结果为
input: tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]) index: tensor([[1, 0], [0, 0], [1, 2]]) dim 1时: tensor([[ 1, 0], [ 5, 5], [11, 12]])
dim 1表示取第二维 也就是第二个中括号中的元素进行处理 仔细观察index
index中[1,0]中的1表示在input中第二维下标为1的元素 也就是1index中[1,0]中的0表示在input中第二维下标为1的元素 也就是0index中[1,2]中的1表示在input中第二维第三组值下标为1的元素。也就是11index中[1,2]中的2表示在input中第二维第三组值下标为2的元素。也就是12注意 dim 1时 index中组的个数要与input组的个数相同
2、dim 0时的情况input torch.arange(15).view(3,5) print( input:n ,input) index1 torch.tensor([ [1,0,0,0,0], [0,0,1,2,1], [1,2,0,0,0]]) print( index:n ,index1) print( dim 0时:n ,torch.gather(input,dim 0,index index1))
输出结果
input: tensor([[ 0, 1, 2, 3, 4], [ 5, 6, 7, 8, 9], [10, 11, 12, 13, 14]]) index: tensor([[1, 0, 0, 0, 0], [0, 0, 1, 2, 1]]) dim 0时: tensor([[ 5, 1, 2, 3, 4], [ 0, 1, 7, 13, 9]])
当dim 0时 表示在第一维中检索下标 input第一维度的数据可以看作
[0,5,10],[1,6,11],[2,7,12],[3,8,13],[4,0,14]
那么在index中的[1,0,0,0,0]中的1表示在[0,5,10]中取下标为1的元素 也就是5。后面可以依次取出并集合到一个tensor中去。
注意 dim 0时 index每一组中的元素个数要与input中的元素个数相同 也就是都为5个。
3、dim 3时相信对torch.gather()有一定的了解 那么在下面举例dim 3的情况
input torch.tensor([[ [1,2,3], [4,5,6], [7,8,9]] index1 torch.tensor([[ [0,0], [0,0], [0,0] print( input:n ,input) print( index:n ,index1) print( dim 3时:n ,torch.gather(test,dim 2,index index1))
结果为
input: tensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]]]) index: tensor([[[0, 0], [0, 0], [0, 0]]]) dim 3时: tensor([[[1, 1], [4, 4], [7, 7]]])
以上就是有关gather()的笔记 不知道我是否有讲清楚 有帮助的点个赞吧~



