栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch之expand repeat

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch之expand repeat

1. expand
tensor.expand(*sizes)
1
expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上。

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1。

expand函数可能导致原始张量的升维,其作用在张量前面的维度上,因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

另一个值得注意的点是:expand函数并不会重新分配内存,返回结果仅仅是原始张量上的一个视图。

下面为几个简单的示例:

import torch
a = tensor([1, 0, 2])
b = a.expand(2, -1)   # 第一个维度为升维,第二个维度保持原阳
# b为   tensor([[1, 0, 2],  [1, 0, 2]])

a = torch.tensor([[1], [0], [2]])
b = a.expand(-1, 2)   # 保持第一个维度,第二个维度只有一个元素,可扩展
# b为  tensor([[1, 1],
#              [0, 0],
#              [2, 2]])

前文提及expand仅能作用于单数维,那对于非单数维的拓展,那就需要借助于repeat函数了。

tensor.repeat(*sizes)
1
参数*sizes指定了原始张量在各维度上复制的次数。整个原始张量作为一个整体进行复制,这与Numpy中的repeat函数截然不同,而更接近于tile函数的效果。

与expand不同,repeat函数会真正的复制数据并存放于内存中。

下面是一个简单的例子:

import torch
a = torch.tensor([1, 0, 2])
b = a.repeat(3,2)  # 在轴0上复制3份,在轴1上复制2份
# b为 tensor([[1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2],
#        [1, 0, 2, 1, 0, 2]])

  在轴0上复制3份,在轴1上复制2份的意思同样可以理解为,在列上复制三份,在每一行上复制2份。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/656579.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号