栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch常用API(2)

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch常用API(2)

学习内容:pytorch常用API(2) 1、张量的索引
a = torch.Tensor(2,3,32,32)
print(a[:,:,:,:].shape)#全要,一个冒号代表全要
print(a[0:1,:,:,:].shape)#取第一张图像,通道、宽度、高度全要
print(a[:,:,0:32:2,0:32:2].shape)#所有图像,通道全要,宽度高度全要但是每隔两个要一个进行一个下采样,变成了16*16

输出:
torch.Size([2, 3, 32, 32])
torch.Size([1, 3, 32, 32])
torch.Size([2, 3, 16, 16])

#pytorch索引
a = torch.linspace(1,12,steps=12)
a = a.view(3,4)
print(a)
b = torch.index_select(a,0,torch.tensor([0,2]))#0在行维度,取第0行和第2行
print(b)
c = torch.index_select(a,1,torch.tensor([1,3]))#1在列维度,取第1列和第3列
print(c)

输出:

2、torch.masked_select()
#torch.masked_select()
a = torch.randn(3,3)#随机三行三列
mask = torch.eye(3,3,dtype=torch.bool)
print(a)
print(mask)
c = torch.masked_select(a,mask)
print(c)#模板取

tensor([[ 0.6159, 0.6601, 2.2437],
[ 0.1718, -0.2925, -0.0763],
[-0.4906, 0.7846, -1.4077]])
tensor([[ True, False, False],
[False, True, False],
[False, False, True]])
tensor([ 0.6159, -0.2925, -1.4077])

3、torch.take()
#torch.take()
a = torch.randn(3,3)
b = torch.tensor([0,2,4,6])#0,2,4,6
c = torch.take(a,b)#先将a打平,再按0,2,4,6去索引
print(a)
print(b)
print(c)

tensor([[-0.2457, -0.8254, -0.2993],
[-0.1371, -0.7704, 0.5670],
[ 0.3647, 0.8135, -0.8036]])
tensor([0, 2, 4, 6])
tensor([-0.2457, -0.2993, -0.7704, 0.3647])

4、维度变化

permute(),可以同时换挪多个维度

#permute(),可以同时换挪多个维度
a = torch.rand(4,3,32,32)
b = a.permute(0,3,2,1)#可以同时换挪多个维度
print(a.shape)
print(b.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 32, 32, 3])

view() reshape()这两个意思一样,可以变换维度,一般情况用reshape,鲁棒性更强一些

#view() reshape()这两个意思一样,可以变换维度,一般情况用reshape,鲁棒性更强一些
a = torch.rand(4,3,32,32)#四维
b = a.view(4,3,32*32)#变成三维,维度变换维数要相等(乘起来相等)
c = a.view(4,-1)#打成两维,剩余多少,用-1代替,自动会计算
print(a.shape)
print(b.shape)
print(c.shape)
d = a.reshape(4,3,8,4,8,4)#也可以升维
print(d.shape)

torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 1024])
torch.Size([4, 3072])
torch.Size([4, 3, 8, 4, 8, 4])

unsqueeze()扩张

#unsqueeze()扩张
a = torch.rand(4,3,32,32)
b = a.unsqueeze(0)#第0维进行扩张
c = a.unsqueeze(2)#第2维进行扩张
d = a.unsqueeze(4)
e = a.unsqueeze(-1)#负索引进行扩张
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
print(e.shape)

torch.Size([4, 3, 32, 32])
torch.Size([1, 4, 3, 32, 32])
torch.Size([4, 3, 1, 32, 32])
torch.Size([4, 3, 32, 32, 1])
torch.Size([4, 3, 32, 32, 1])

squeeze()压缩,只能压缩一维

#squeeze()压缩,只能压缩一维
a = torch.rand(1,1,32,32)
b = a.squeeze(0)#第0维进行压缩
c = a.squeeze(1)#第1维进行压缩
d = a.squeeze(3)#压不了,因为索引的维度不为1,但不会报错
e = a.squeeze(-1)#负索引
print(a.shape)
print(b.shape)
print(c.shape)
print(d.shape)
print(e.shape)

torch.Size([1, 1, 32, 32])
torch.Size([1, 32, 32])
torch.Size([1, 32, 32])
torch.Size([1, 1, 32, 32])
torch.Size([1, 1, 32, 32])

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/883434.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号