栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch中的torch.tensor.repeat以及torch.tensor.expand用法

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch中的torch.tensor.repeat以及torch.tensor.expand用法

torch.tensor.expand

先看招

import torch
x = torch.tensor([[1], [2], [3]])
print(x.size())
print(x.expand(3, 4))
print(x.expand(-1, 4))   # -1 means not changing the size of that dimension,所以原来是3,现在仍然是3,故和上述等价

说白了,就是复制!!!怎么复制呢?原来是[3,1],现在要变成[3,4],所以是对原tensor中第二个维度里面的数进行复制!!

要求:
被扩张的那个维度必须只有一个数!!也就是说size必须是1!!,所以原tensor必须size是[3,1],不可以是[3,2],否则报错。即:

tensor with singleton dimensions expanded to a larger size.

torch.tensor.repeat

同样都是复制,这个比上面这个好用。
上面这个功能可以如下实现:

x = torch.tensor([[1], [2], [3]])
print(x.size())
print(x.repeat(1, 4))#用法不一样的地方,不变的地方用1表示,而不是-1.


而且其不需要扩张的维度严格要求为1,例如可以是[3,2],例如:

x = torch.tensor([[1,1], [2,1], [3,0]])
print(x.size())
print(x.repeat(1, 4))


这才是真正的复制啊。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/326029.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号