栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch tensor按广播赋值 scatter

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch tensor按广播赋值 scatter

普通广播
>>> import torch
>>> a = torch.tensor([[1,2,3],[4,5,6]])
# 和a shape相同,但是用0填充
>>> b = torch.full_like(a,0)
>>> c = torch.tensor([[0,0,1],[1,0,1]])
# 赋值索引
>>> c[:,0]
tensor([0, 1])
# 赋值语句:使用广播机制进行赋值
>>> b[range(n),c[:,0]] = 1
>>> b
tensor([[1, 0, 0],
        [0, 1, 0]])

为什么会出现这样的结果?

赋值语句的意思是:

    range(n)表示对b的所有行进行赋值操作c[:,0]] 表示执行赋值操作的b的列索引,[0, 1] 表示第一行对索引为0的列进行操作(赋值为1);第二行对索引为1的列进行操作(赋值为1)最右边的1表示对应索引位置所赋的值
scatter函数
import torch
label = torch.zeros(3, 6) #首先生成一个全零的多维数组
print("label:",label)
a = torch.ones(3,5)
 
b = [[0,1,2],[0,1,3],[1,2,3]]
#这里需要解释的是,b的行数要小于等于label的行数,列数要小于等于a的列数
print(a)
label.scatter_(1,torch.LongTensor(b),a) 
#参数解释:‘1’:需要赋值的维度,是label的维度;‘torch.LongTensor(b)’:需要赋值的索引;‘a’:要赋的值
print("new_label: ",label)

label: 
tensor([[0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0.]])
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])
new_label:  
tensor([[1., 1., 1., 0., 0., 0.],
        [1., 1., 0., 1., 0., 0.],
        [0., 1., 1., 1., 0., 0.]])

举例:

>>> b = torch.full_like(a,0)
>>> b
tensor([[0, 0, 0],
        [0, 0, 0]])
>>> c = torch.tensor([[0,0],[1,0]])
>>> c
tensor([[0, 0],
        [1, 0]])
# 1表示对b的列进行赋值,以c的每一行的值作为b的列索引,一行一行地进行赋值
# c第一行 [0,0] 表示分别将b的 第一行 第0列、第0列 元素赋值为1 (重复操作了)
# c第二行 [1,0] 表示 将b的 第1列、第0列 元素赋值为1 (逆序了)
# 上面的这两个赋值操作其实有重复的、逆序的
>>> b.scatter_(1,torch.LongTensor(c),1)
>>> b
tensor([[1, 0, 0],
        [1, 1, 0]])

感谢!https://blog.csdn.net/qq_41368074/article/details/106986753

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/769522.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号