栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch.scatter函数详解

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch.scatter函数详解

#torch.scatter函数官方解释

scatter(output, dim, index, src) → Tensor

Writes all values from the tensor src into self at the indices specified in the index tensor. For each value in src, its output index is specified by its index in src for dimension != dim and by the corresponding value in index for dimension = dim.

For a 3-D tensor, self is updated as:

  • output[index[i][j][k]][j][k] = src[i][j][k]  # if dim == 0
  • output[i][index[i][j][k]][k] = src[i][j][k]  # if dim == 1
  • output[i][j][index[i][j][k]] = src[i][j][k]  # if dim == 2

This is the reverse operation of the manner described in gather().

self, index and src (if it is a Tensor) should have same number of dimensions. It is also required that index.size(d) <= src.size(d) for all dimensions d, and that index.size(d) <= self.size(d) for all dimensions d != dim.

Moreover, as for gather(), the values of index must be between 0 and self.size(dim) - 1 inclusive, and all values in a row along the specified dimension dim must be unique.

Parameters

  • dim (int) – the axis along which to index
  • index (LongTensor) – the indices of elements to scatter, can be either empty or the same size of src. When empty, the operation returns identity
  • src (Tensor) – the source element(s) to scatter, incase value is not specified
  • value (float) – the source element(s) to scatter, incase src is not specified

总结:scatter函数就是把src数组中的数据重新分配到output数组当中,index数组中表示了要把src数组中的数据分配到output数组中的位置,若未指定,则填充0.

#通过例子理解函数

import torch
 
input = torch.randn(2, 4)
print(input)
output = torch.zeros(2, 5)
index = torch.tensor([[3, 1, 2, 0], [1, 2, 0, 3]])
output = output.scatter(1, index, input)
print(output)

#得到输出
tensor([[-0.2558, -1.8930, -0.7831,  0.6100],
        [ 0.3246,  2.1289,  0.5887,  1.5588]])

tensor([[ 0.6100, -1.8930, -0.7831, -0.2558,  0.0000],
        [ 0.5887,  0.3246,  2.1289,  1.5588,  0.0000]])

建议从input数组出发,结合官方给出的位置替换进行理解。

数据位置发生的变化都是在第1维上,第0维不变。若dim=0,则同理变换input第一维的下标。

  • input[0][0] = output[0][index[0][0]] = output[0][3]
  • input[0][1] = output[0][index[0][1]] = output[0][1]
  • input[0][2] = output[0][index[0][2]] = output[0][2]
  • input[0][3] = output[0][index[0][3]] = output[0][0]
  • Input[1][0] = output[1][index[1][0]] = output[1][1]
  • input[1][1] = output[1][index[1][1]] = output[1][2]
  • input[1][2] = output[1][index[1][2]] = output[1][0]
  • input[1][3] = output[1][index[1][3]] = output[1][3]

一般scatter用于生成onehot向量,如下所示:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

#输出
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])

#如果input是一个数字的话,代表这用于分配到output的数字是多少。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/580755.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号