栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch.scatter()

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch.scatter()

功能:把source矩阵的值,根据索引“撒”到目标矩阵中。

格式:torch.scatter(dim, index, source)

dim:维度

index:对该维度的索引

source:数据来源

例子:

x = torch.rand(2, 5)

#tensor([[0.1940, 0.3340, 0.8184, 0.4269, 0.5945],
#        [0.2078, 0.5978, 0.0074, 0.0943, 0.0266]])

torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

#tensor([[0.1940, 0.5978, 0.0074, 0.4269, 0.5945],
#        [0.0000, 0.3340, 0.0000, 0.0943, 0.0000],
#        [0.2078, 0.0000, 0.8184, 0.0000, 0.0266]])

解释:

 数据源头是x,x有10个值,现在把这10个值撒到[3, 5]的矩阵中,那么每个值都要有一个新的位置索引,这个新的索引由index指定。

首先,有10个坑位:

然后把index写进去,dim=0,表示index代表第0维;

0       1200
20012

最后,按照自然顺序补充第二维索引

0 01 12 20 30 4
2 00 10 21 32 4

以第一行22为例,表示把x中[0, 2]的数据【0.8184】,路由到目标矩阵的[2, 2]位置。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339690.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号