看values的值 讲一下第一行元素[9,7,6,7]是如何得来的
因为dim 0所以要从第0维来看 将数据分成3份 分别是
1. [[9, 1, 3, 7], [8, 7, 8, 2], [2, 1, 9, 3], [7, 4, 2, 3]] 2. [[1, 1, 6, 4], [9, 3, 9, 3], [9, 6, 8, 9], [7, 1, 2, 1]] 3. [[9, 7, 2, 7], [2, 9, 9, 8], [3, 4, 8, 8], [2, 6, 5, 8]]
要以这三个tensor为单位进行topk的筛选 首先比较每一个tensor的第一行 因为参数k为2 所以就要找到这3组元素中的最大值和次大值 作为最后的输出。因此最大值就是[9 7 6 7] 次大值为:[9 1 3 7]这样就完成了筛选。索引值也就是当前位置处的元素 是来自于这三个元素中的哪一个。我认为把这个看懂后面就可以迎刃而解 大家可以仔细理解一下不太懂的话也没关系 看完后面两个可能这个就懂了。
values1 , indices1 a.topk(2,dim 1) print(values1) print(indices1)
torch.Size([3, 2, 4]) tensor([[[9, 7, 9, 7], [8, 4, 8, 3]], [[9, 6, 9, 9], [9, 3, 8, 4]], [[9, 9, 9, 8], [3, 7, 8, 8]]]) torch.Size([3, 2, 4]) tensor([[[0, 1, 2, 0], [1, 3, 1, 2]], [[1, 2, 1, 2], [2, 1, 2, 0]], [[0, 1, 1, 2], [2, 0, 2, 3]]])
这个例子是dim 1时 类比于dim 0的情况。这里是对第一维进行筛选操作。需要注意的是这里第0维的三个元素是分开操作的。这里我提供一种我自己的理解思路大家借鉴。首先还是按照第0维将tensor分为3块
1. [[9, 1, 3, 7], [8, 7, 8, 2], [2, 1, 9, 3], [7, 4, 2, 3]] 2. [[1, 1, 6, 4], [9, 3, 9, 3], [9, 6, 8, 9], [7, 1, 2, 1]] 3. [[9, 7, 2, 7], [2, 9, 9, 8], [3, 4, 8, 8], [2, 6, 5, 8]]
这里每一块中的第0维就是总体tensor的第一维 从第0维来看就是4个14的向量 因此就是对这4向量取最大值和次大值。也就是在这个44的张量中选出对应位置的最大值和次大值。例如第一块中筛选出的结果就是[9 7 9 7]和[8 4 8 3]其他同理 索引值表示当前位置处的值是来自哪一个向量。
values2 , indices2 a.topk(2,dim 2) print(values2.shape) print(values2) print(indices2.shape) print(indices2)
torch.Size([3, 4, 2]) tensor([[[9, 7], [8, 8], [9, 3], [7, 4]], [[6, 4], [9, 9], [9, 9], [7, 2]], [[9, 7], [9, 9], [8, 8], [8, 6]]]) torch.Size([3, 4, 2]) tensor([[[0, 3], [0, 2], [2, 3], [0, 1]], [[2, 3], [0, 2], [3, 0], [0, 2]], [[0, 1], [1, 2], [2, 3], [3, 1]]])
类比前两种情况的思考方式 这里的操作就是对整个张量最内层做的操作 也就是整体张量形状(3,4,4)中的4这个4就是最内层每一个一维向量中的4个元素 取对应的最大值和次大值 应该也容易理解。大家可以对比着三种情况的输入输出加以理解。
另外 k参数默认是最后一维
然后研究一下lagest参数
直接用最后一维
values2 , indices2 a.topk(2,dim 2,largest False) print(values2.shape) print(values2) print(indices2.shape) print(indices2)
torch.Size([3, 4, 2]) tensor([[[1, 3], [2, 7], [1, 2], [2, 3]], [[1, 1], [3, 3], [6, 8], [1, 1]], [[2, 7], [2, 8], [3, 4], [2, 5]]]) torch.Size([3, 4, 2]) tensor([[[1, 2], [3, 1], [1, 0], [2, 3]], [[1, 0], [1, 3], [1, 2], [1, 3]], [[2, 3], [0, 3], [0, 1], [0, 2]]])
很明显 只是取最小和次小
下面是sorted参数
依然用dim 2进行测试
values2 , indices2 a.topk(2,dim 2,sorted False) print(values2.shape) print(values2) print(indices2.shape) print(indices2)
torch.Size([3, 4, 2]) tensor([[[9, 7], [8, 8], [9, 3], [7, 4]], [[6, 4], [9, 9], [9, 9], [7, 2]], [[9, 7], [9, 9], [8, 8], [8, 6]]]) torch.Size([3, 4, 2]) tensor([[[0, 3], [0, 2], [2, 3], [0, 1]], [[2, 3], [0, 2], [3, 0], [0, 2]], [[0, 1], [1, 2], [2, 3], [3, 1]]])
这里和不sorted True好像并没有区别 不知道要怎么理解 网上也没找到类似的解释 希望有知道的大佬可以多多指教
如有错误请多指正


