栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch.gather()函数深入理解(dim=1,2,3三种维度分析)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch.gather()函数深入理解(dim=1,2,3三种维度分析)

首先我要吐槽torch.gather()函数的官方文档 请问它在说个啥 后来根据csdn以及自己的学习 总结出torch.gather()的用法

首先 给出torch.gather()中的几个参数

torch.gather(input, dim, index, out None, sparse_grad False) → Tensor

常用的就是input,dim,index三个参数

input: 你要输入的torch.tensor() dim: 要处理的维度 一个[ ]表示一个维度 比如[ [ 2,3 ] ]中的2和3就是在第二维,dim可以取0 1 2 index: 必须为torch.LongTensor()的类型 且维度大小必须和input相同 index中每一个值表示input在dim维中的下标 下标从0开始

说了这么多估计也没说明白 正常正常 先上几个例子自己理解理解

1、dim 1时的情况
input torch.arange(15).view(3,5)
print( input:n ,input)
index1 torch.tensor([
 [1, 0],
 [0, 0],
 [1, 2]])
print( index:n ,index)
print( dim 1时:n ,torch.gather(input,dim 1,index index1))

结果为

input:
 tensor([[ 0, 1, 2, 3, 4],
 [ 5, 6, 7, 8, 9],
 [10, 11, 12, 13, 14]])
index:
 tensor([[1, 0],
 [0, 0],
 [1, 2]])
dim 1时:
 tensor([[ 1, 0],
 [ 5, 5],
 [11, 12]])

dim 1表示取第二维 也就是第二个中括号中的元素进行处理 仔细观察index

index中[1,0]中的1表示在input中第二维下标为1的元素 也就是1index中[1,0]中的0表示在input中第二维下标为1的元素 也就是0index中[1,2]中的1表示在input中第二维第三组值下标为1的元素。也就是11index中[1,2]中的2表示在input中第二维第三组值下标为2的元素。也就是12

注意 dim 1时 index中组的个数要与input组的个数相同

2、dim 0时的情况
input torch.arange(15).view(3,5)
print( input:n ,input)
index1 torch.tensor([
 [1,0,0,0,0],
 [0,0,1,2,1],
 [1,2,0,0,0]])
print( index:n ,index1)
print( dim 0时:n ,torch.gather(input,dim 0,index index1))

输出结果

input:
 tensor([[ 0, 1, 2, 3, 4],
 [ 5, 6, 7, 8, 9],
 [10, 11, 12, 13, 14]])
index:
 tensor([[1, 0, 0, 0, 0],
 [0, 0, 1, 2, 1]])
dim 0时:
 tensor([[ 5, 1, 2, 3, 4],
 [ 0, 1, 7, 13, 9]])

当dim 0时 表示在第一维中检索下标 input第一维度的数据可以看作

[0,5,10],[1,6,11],[2,7,12],[3,8,13],[4,0,14]

那么在index中的[1,0,0,0,0]中的1表示在[0,5,10]中取下标为1的元素 也就是5。后面可以依次取出并集合到一个tensor中去。

注意 dim 0时 index每一组中的元素个数要与input中的元素个数相同 也就是都为5个。

3、dim 3时

相信对torch.gather()有一定的了解 那么在下面举例dim 3的情况

input torch.tensor([[
 [1,2,3],
 [4,5,6],
 [7,8,9]]
index1 torch.tensor([[
 [0,0],
 [0,0],
 [0,0]
print( input:n ,input)
print( index:n ,index1)
print( dim 3时:n ,torch.gather(test,dim 2,index index1))

结果为

input:
 tensor([[[1, 2, 3],
 [4, 5, 6],
 [7, 8, 9]]])
index:
 tensor([[[0, 0],
 [0, 0],
 [0, 0]]])
dim 3时:
 tensor([[[1, 1],
 [4, 4],
 [7, 7]]])

以上就是有关gather()的笔记 不知道我是否有讲清楚 有帮助的点个赞吧~

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/267717.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号