栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch squeeze鍑芥暟_unsqueeze pytorch?

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch squeeze鍑芥暟_unsqueeze pytorch?

顾名思义:unsqueeze,扩展维度,返回一个新的张量,对输入的既定位置插入维度 1

                  squeeze,压缩维度,将输入张量形状中的1 去除并返回。

torch.unsqueeze(input, dim)

torch.squeeze(input, dim)

tensor (Tensor) – 输入张量dim (int) – 插入/消除 维度的索引

以下用一个二维张量进行举例:

压缩维度仅对(0,1)索引进行示例,(-1,-2)原理类似

import torch

x = torch.Tensor([[1, 2, 3, 4],
                 [5,6,7,8]])  
print('#' * 50)
print(x)  
print(x.size())  
print(x.dim())  

##########
print('#' * 50)
print(torch.unsqueeze(x, 0))  
print(torch.unsqueeze(x, 0).size())  
print(torch.unsqueeze(x, 0).dim())  
m=torch.unsqueeze(x, 0)
print(m.squeeze(0))
n=m.squeeze(0)
print(n.size())
print(n.dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, 1))
print(torch.unsqueeze(x, 1).size())  
print(torch.unsqueeze(x, 1).dim())  
a=torch.unsqueeze(x, 1)
print(a.squeeze(1))
b=a.squeeze(1)
print(b.size())
print(b.dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, -1))
print(torch.unsqueeze(x, -1).size())  
print(torch.unsqueeze(x, 1).dim())

##########
print('#' * 50)
print(torch.unsqueeze(x, -2))  
print(torch.unsqueeze(x, -2).size())  
print(torch.unsqueeze(x, -2).dim())  

相应结果:

##################################################
tensor([[1., 2., 3., 4.], 
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.],
         [5., 6., 7., 8.]]])
torch.Size([1, 2, 4])
3
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1., 2., 3., 4.]],

        [[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3
tensor([[1., 2., 3., 4.],
        [5., 6., 7., 8.]])
torch.Size([2, 4])
2
##################################################
tensor([[[1.],
         [2.],
         [3.],
         [4.]],

        [[5.],
         [6.],
         [7.],
         [8.]]])
torch.Size([2, 4, 1])
3
##################################################
tensor([[[1., 2., 3., 4.]],

        [[5., 6., 7., 8.]]])
torch.Size([2, 1, 4])
3

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/782987.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号