栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch中tensor的直接赋值与clone()、numpy()

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch中tensor的直接赋值与clone()、numpy()

目录

1. 直接赋值

2. 使用clone()函数

3. 使用numpy()

4. 关于使用

4.1 比如我们想保存过程值时,使用直接保存,可以看到列表中的值是完全错误的

使用clone(),即可得到想要的结果

5. 总结


1. 直接赋值
import torch
t1 = torch.ones(3, 3)
t2 = t1
t1[:, 2] = 2
print(f"id(t1) = id(t1)")
print(f"id(t2) = id(t2)")
print(f"t1 = {t1}")
print(f"t2 = {t2}")

输出:

可以看到t1和t2的地址是一样的,改变t1的值,t1的值也随之改变

id(t1) = 1955897835832
id(t2) = 1955897835832
t1 = tensor([[1., 1., 2.],
            [1., 1., 2.],
            [1., 1., 2.]])
t2 = tensor([[1., 1., 2.],
            [1., 1., 2.],
            [1., 1., 2.]])

2. 使用clone()函数
import torch
t1 = torch.ones(3, 3)
t2 = t1.clone()
t1[:, 2] = 2
print(f"id(t1) = {id(t1)}")
print(f"id(t2) = {id(t2)}")
print(f"t1 = {t1}")
print(f"t2 = {t2}")

输出:

可以看到t2的地址改变了,改变t1的值,t2的值不受影响

id(t1) = 2032020259128
id(t2) = 2032020258968
t1 = tensor([[1., 1., 2.],
            [1., 1., 2.],
            [1., 1., 2.]])
t2 = tensor([[1., 1., 1.],
            [1., 1., 1.],
            [1., 1., 1.]])

3. 使用numpy()
import torch
t1 = torch.ones(3, 3)
t2 = t1.numpy()
t1[:, 2] = 2
print(f"id(t1) = {id(t1)}")
print(f"id(t2) = {id(t2)}")
print(f"t1 = {t1}")
print(f"t2 = {t2}")

输出:

可以看到t2的地址改变了,但改变t1的值时仍会改变,表明t2和t1的值取自一个位置

id(t1) = 2567195427048
id(t2) = 2567195426864
t1 = tensor([[1., 1., 2.],
             [1., 1., 2.],
             [1., 1., 2.]])
t2 = [[1. 1. 2.]
      [1. 1. 2.]
      [1. 1. 2.]]

4. 关于使用

4.1 比如我们想保存过程值时,使用直接保存,可以看到列表中的值是完全错误的
import torch
t1 = torch.ones(1)
list = []
for i in range(10):
    t1[0] = i
    list.append(t1[0]) # list.append(t1[0].numpy()) 会得到一样的结果
    print(list)
print(t1)
print(list)
-------------------------输出------------------------------
[tensor(0.)]
[tensor(1.), tensor(1.)]
[tensor(2.), tensor(2.), tensor(2.)]
[tensor(3.), tensor(3.), tensor(3.), tensor(3.)]
[tensor(4.), tensor(4.), tensor(4.), tensor(4.), tensor(4.)]
[tensor(5.), tensor(5.), tensor(5.), tensor(5.), tensor(5.), tensor(5.)]
[tensor(6.), tensor(6.), tensor(6.), tensor(6.), tensor(6.), tensor(6.), tensor(6.)]
[tensor(7.), tensor(7.), tensor(7.), tensor(7.), tensor(7.), tensor(7.), tensor(7.), tensor(7.)]
[tensor(8.), tensor(8.), tensor(8.), tensor(8.), tensor(8.), tensor(8.), tensor(8.), tensor(8.), tensor(8.)]
[tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.)]
tensor([9.])
[tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.), tensor(9.)]

4.2 使用clone(),即可得到想要的结果
[tensor(0.)]
[tensor(0.), tensor(1.)]
[tensor(0.), tensor(1.), tensor(2.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.), tensor(6.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.), tensor(6.), tensor(7.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.), tensor(6.), tensor(7.), tensor(8.)]
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.), tensor(6.), tensor(7.), tensor(8.), tensor(9.)]
tensor([9.])
[tensor(0.), tensor(1.), tensor(2.), tensor(3.), tensor(4.), tensor(5.), tensor(6.), tensor(7.), tensor(8.), tensor(9.)]

5. 总结

当我们想要存储或者是使用中间值时,需要保留clone()的值,不然会得到完全错误的结果

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/866829.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号