- gather函数的输出规则
- 第一条规则
- 第二条规则
- gather函数内部的代码机理推测
- 代码示例
- 输出结果
- 参考文献
事先声明:本文只会对二维张量的gather操作进行介绍,三维张量的gather操作规则在csdn上的博文屡见不鲜。本文的解释是从个人的理解出发,相信解释也会对理解三维张量的操作规则起到触类旁通的作用。 gather函数的输出规则
o
u
t
[
i
]
[
j
]
=
i
n
p
u
t
[
i
n
d
e
x
[
i
]
[
j
]
]
[
j
]
,
i
f
d
i
m
=
=
0
out [i] [j] = input [index [i] [j] ] [j], if{ }dim == 0
out[i][j]=input[index[i][j]][j],if dim==0
o
u
t
[
i
]
[
j
]
=
i
n
p
u
t
[
i
]
[
i
n
d
e
x
[
i
]
[
j
]
]
,
i
f
d
i
m
=
=
1
out [i] [j] = input [i] [index [i] [j] ], if{ }dim == 1
out[i][j]=input[i][index[i][j]],if dim==1
从行的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个
2
∗
3
2*3
2∗3的维度的张量,index为
[
0
,
1
,
1
]
[0, 1, 1]
[0,1,1],取
d
i
m
=
0
dim=0
dim=0,根据规则,外层循环为变量
i
i
i,内层循环为变量
j
j
j,且
i
i
n
r
a
n
g
e
(
0
,
2
)
;
j
i
n
r
a
n
g
e
(
0
,
3
)
i { }in{ } range(0, 2); j{ } in{ } range(0, 3)
i in range(0,2);j in range(0,3)。
代入
i
=
0
,
j
=
1
i=0,{ }j=1
i=0, j=1,得到:
o
u
t
[
0
]
[
1
]
=
i
n
p
u
t
[
i
n
d
e
x
[
0
]
[
1
]
]
[
1
]
out[0][1]=input[index[0][1]][1]
out[0][1]=input[index[0][1]][1]
o
u
t
[
0
]
[
1
]
=
i
n
p
u
t
[
1
]
[
1
]
out[0][1]=input[1][1]
out[0][1]=input[1][1]
即:该输出元素为输入的
2
∗
3
2*3
2∗3维度张量的第1行第1列元素。且该元素在输出张量中处在第0行第1列的位置。
如下表所示:
| 0 | 1 | 2 | |
|---|---|---|---|
| 0 | |||
| 1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
第二条规则从列的角度出发,输入的index张量按照如上规则,取出对应的输入张量的元素。
例如:一个
2
∗
3
2*3
2∗3的维度的张量,index为
[
[
0
,
1
,
1
]
,
[
1
,
1
,
1
]
]
[[0, 1, 1],[1, 1, 1]]
[[0,1,1],[1,1,1]],取
d
i
m
=
1
dim=1
dim=1,根据规则,外层循环为变量
i
i
i,内层循环为变量
j
j
j,且
i
i
n
r
a
n
g
e
(
0
,
2
)
;
j
i
n
r
a
n
g
e
(
0
,
3
)
i { }in{ } range(0, 2); j{ } in{ } range(0, 3)
i in range(0,2);j in range(0,3)。
代入
i
=
1
,
j
=
1
i=1,{ }j=1
i=1, j=1,得到:
o
u
t
[
1
]
[
1
]
=
i
n
p
u
t
[
1
]
[
i
n
d
e
x
[
1
]
[
1
]
]
out[1][1]=input[1][index[1][1]]
out[1][1]=input[1][index[1][1]]
o
u
t
[
1
]
[
1
]
=
i
n
p
u
t
[
1
]
[
1
]
out[1][1]=input[1][1]
out[1][1]=input[1][1]
即:该输出元素为输入的
2
∗
3
2*3
2∗3维度张量的第1行第1列元素,且该元素在输出张量中处在第1行第1列的位置。
如下表所示:
| 0 | 1 | 2 | |
|---|---|---|---|
| 0 | |||
| 1 | this element |
其中, 0 , 1 , 2 {0, 1, 2} 0,1,2代表列标号, 0 , 1 {0, 1} 0,1代表行标号。
gather函数内部的代码机理推测声明:下述代码仅针对原理部分编写,距离函数内部真实情况仍存在较大差距,且下述代码的严谨性不够,故仅供理解gather的核心规则。
def gather(input, dim, index): # 这里的dim要求取0或1 out = [] m = input.size()[0] # size函数是torch的方法 n = input.size()[1] for i in range(m): for j in range(n): if dim == 0: out [i] [j] = input [index [i] [j] ] [j] if dim == 1: out [i] [j] = input [i] [index [i] [j] ] return out代码示例
与上一篇博文内容相同,这里再次展示一遍。
import torch # 设置一个随机种子 torch.manual_seed(100) # 生成一个形状为2*3的矩阵 x = torch.randn(2, 3) print(x) # 获取指定索引对应的值 index = torch.LongTensor([[0, 1, 1]]) print(torch.gather(x, 0, index)) index = torch.LongTensor([[0, 1, 1], [1, 1, 1]]) a = torch.gather(x, 1, index) print(a)输出结果 参考文献
吴茂贵,郁明敏,杨本法,李涛,张粤磊. Python深度学习(基于Pytorch). 北京:机械工业出版社,2019.



