在torch._C._Tensorbase.py中,定义了scatter_(self, dim, index, src, reduce=None) -> Tensor方法,作用是将src的值写入index指定的self相关位置中。用一个三维张量举例如下,将src在坐标(i,j,k)下的所有值,写入self的相应位置,而self的位置坐标除了dim维度用index[i,j,k]代替以外,都不变:
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0,用index[i][j][k]替换i坐标 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1,用index[i][j][k]替换j坐标 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2,用index[i][j][k]替换k坐标
要求:
- self,index,src必须有相同的维数;
- index在任意维度的size必须小于等于self和src对应维度的size
- self和index中元素的类型必须一致,dtype
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152, 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
"""
理解一下:
self是一个shape为(3,5)的全零tensor;
index是一个shape为(2,5)的tensor;
x同index的shape相同,不相同也可。
dim=0,意味着index需要修改第0维坐标;
原始坐标为:00,01,02,03,04;10,11,12,13,14
更新的横坐标依次为:01200;20012
更新的纵坐标依次为:01234;01234
对应组合,更新坐标为:00,11,22,03,04;20,01,02,13,24
然后用x在原始坐标下的值填写到self更新后的坐标位置,将原始坐标和更新坐标对应来看。
具体来看:
x new_self
00 00
01 11
02 22
03 03
04 04
10 20
11 01
12 02
13 10
14 24
"""
图示上述例子:
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
[ 0.0000, 0.0000, 0.0000, 1.2300]])
"""
理解一下:一个2*1的index_tensor(一个2维张量,两个维度的size分别是2和1,对应两个值为2和3),dim=1,需要修改的就是1维。
原来的坐标是00,10;修改后的坐标是02,13。
然后用目标值1.23去替换self中坐标02,13的值,得到上述结果。
"""
>>> z = torch.ones(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23, reduce='multiply')
>>> z
tensor([[1.0000, 1.0000, 1.2300, 1.0000],
[1.0000, 1.0000, 1.0000, 1.2300]])
"""
同上:用目标值找到self在更新坐标位置的值,乘以目标值1.23得到更新后的矩阵。
"""
类似于上述方法,在python中还包括scatter_add(dim, index, src) -> Tensor用于实现将src按照index位置累加到self上。
二、手动实现上述函数分为以下几个步骤:
- 将所有的坐标按照从上到下,从左到右的顺序存储到数组raw_index中;
- 按照dim和index修改原始坐标,得到新的坐标index_pos;
- self_tensor在index_pos位置的值要累加上other_tensor在raw_index位置的值
import torch
import numpy as np
from torch import Tensor
"""
@overload
def scatter_add(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
@overload
def scatter_add(self, dim: Union[str, ellipsis, None], index: Tensor, src: Tensor) -> Tensor: ...
def scatter_add_(self, dim: _int, index: Tensor, src: Tensor) -> Tensor: ...
对pytorch中的scatter_add函数的理解和简单测试:
# 参数:tensor,dim,index,tensor
# 返回:tensor
# 功能:将other_tensor的值累加到self_tensor的相应位置,用index_tensor对应位置的值替换掉self_tensor下标的dim维
# 举例:
self_tensor = [[1, 2], [3, 4]] shape=(2,2)
other_tensor = [[5, 6], [7, 8]] shape=(2,2)
index_tensor = [[0, 0], [1, 1]] shape=(2,2)
dim = 1
以上三个tensor的shape必须一致,下标为:[0,0] [0,1] [1,0] [1,1]
dim=1,那么,self_tensor的第1维下标由index_tensor表示,[0,0] [0,0] [1,1] [1,1]
则:
self_tensor[0,0] = 1 + 5 + 6 = 12
self_tensor[0,1] = 2
self_tensor[1,0] = 3
self_tensor[1,1] = 4 + 7 + 8 = 19
"""
def scatter_add(input_tensor: torch.Tensor, dim: int, index: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
# tensor的维数是不确定的,因此无法用for循环的方式
# 如果tensor是2维,那么dim=0或1,两层for循环,用other对self进行填充
# 如果tensor是3维,那么dim=0、1、2,需要三层for循环来遍历other
if input_tensor.dim() == 2:
for i in range(index_tensor.size()[0]):
for j in range(index_tensor.size()[1]):
if dim == 0: # self矩阵的第0维索引
self_tensor[index_tensor[i][j]][j] += other_tensor[i][j]
elif dim == 1: # self矩阵的第1维索引
self_tensor[i][index_tensor[i][j]] += other_tensor[i][j]
elif input_tensor.dim() == 3:
pass
return self_tensor
if __name__ == '__main__':
index_tensor = torch.tensor([[0, 0], [1, 1]])
print('index_tensor: n', index_tensor.dim())
self_tensor = torch.arange(1, 5).view(2, 2)
print('self_tensor: n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: n', other_tensor)
dim = 1
for i in range(index_tensor.size()[0]):
for j in range(index_tensor.size()[1]):
replace_index = index_tensor[i][j]
print(i, j, replace_index)
if dim == 0:
# self矩阵的第0维索引
self_tensor[replace_index][j] += other_tensor[i][j]
elif dim == 1:
# self矩阵的第1维索引
self_tensor[i][replace_index] += other_tensor[i][j]
print(self_tensor)
index_tensor = torch.tensor([[0, 1], [1, 1]])
print('index_tensor: n', index_tensor)
self_tensor = torch.arange(0, 4).view(2, 2)
print('self_tensor: n', self_tensor)
other_tensor = torch.arange(5, 9).view(2, 2)
print('other_tensor: n', other_tensor)
self_tensor.scatter_add_(dim=0, index=index_tensor, src=other_tensor)
print(self_tensor)
四、其他语言实现
五、小tips
1 python的多维数组下标存取
import numpy as np a = np.arange(3 * 4 * 5).reshape((3, 4, 5)) print(a) """ [[[ 0 1 2 3 4] [ 5 6 7 8 9] [10 11 12 13 14] [15 16 17 18 19]] [[20 21 22 23 24] [25 26 27 28 29] [30 31 32 33 34] [35 36 37 38 39]] [[40 41 42 43 44] [45 46 47 48 49] [50 51 52 53 54] [55 56 57 58 59]]] """ # 三维数组,下标,举例(1,2,0) # 第一种方式,所有语言共同的读取方式,一般通过多层循环嵌套生成不同维度下标来读取 print(a[1][2][0]) # 30 # 第二种方式,python独有,将需要的三个维度下标位置直接放入中括号中,就可以读取; # 适合于不同维度的数组,通过已知的下标位置读取值 print(a[1, 2, 0]) # 30 print(a[(1, 2, 0)]) # 30 # 第三种方式,一般将下标位置提前用list存储,只能得到多个list组合的数组;要想达到上述要求,可以将list转为tuple pos = [1, 2, 0] print(a[pos]) # 正解:a[tuple(pos)] """ [[[20 21 22 23 24] [25 26 27 28 29] [30 31 32 33 34] [35 36 37 38 39]] [[40 41 42 43 44] [45 46 47 48 49] [50 51 52 53 54] [55 56 57 58 59]] [[ 0 1 2 3 4] [ 5 6 7 8 9] [10 11 12 13 14] [15 16 17 18 19]]] """2 python深拷贝
按照dim修改坐标位置时,需要用到深拷贝,可以参考这篇博文。 Java基础-Cloneable接口,深浅拷贝【附python,C++深拷贝、浅拷贝】
参考:- 源码
- 官网对于scatter_add_的解释
- PyTorch扩展自定义PyThon/C++(CUDA)算子的若干方法总结
- python多维数组的下标存取



