栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch 学习笔记——索引与切片

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch 学习笔记——索引与切片

索引
In [130]:a=torch.rand(4,3,28,28)
In [131]: a[0].shape
0ut[131]: torch.Size([3, 28, 28])

In [138]: a[0,0].shape
0ut[138]: torch.Size([28, 28])

In [139]: a[0,0,2,4]
0ut[139]: tensor(0. 8082)
前/后N项
In [140]: a.shape
0ut[140]: torch.Size([4, 3, 28, 28])

In [141]: a[:2].shape//第0个维度从0到1
0ut[141]: torch.Size([2, 3, 28, 28])

In [142]: a[:2,:1,:,:].shape
0ut[142]: torch.Size([2,1, 28, 28])

In [143]: a[:2,1:,:,:].shape 
Out[143]: torch.Size([2, 2,28, 28])

In [144]: a[:2,-1:,:,:].shape//从最后一个元素到末尾
0ut[144]: torch.Size([2, 1, 28, 28])
有间隔地选取
In [145]: a[:,:,0:28:2,0:28:2].shape//从0到28间隔2选择
Out[145]: torch.Size([4, 3, 14, 14])

In [146]: a[:,:,::2,::2].shape//从头到尾间隔step步长选择
Out[146]: torch.Size([4, 3,14, 14])

//通用形式:start:end:step
按特定索引选取

index_select()

In [149]: a.shape
0ut[149]: torch.Size([4, 3, 28, 28])

In [159]: a.index_select(0, torch.sensor([0, 2])).shape
0ut[159]: torch.Size([2, 3, 28, 28])


In [159]: a.index_select(1, torch.sensor([1, 2])).shape
0ut[159]: torch.Size([4, 2, 28, 28])

In [168]: a.index_select(2, torch.arange(8)).shape
0ut[168]: torch.Size([4, 3, 8, 28])

//.index_select第二个参数不能以list的形式直接输入,需要以tensor的形式输入

In [149]: a.shape
Out[149]: torch.Size([4, 3, 28, 28])

In [150]: a[...].shape
0ut[150]: torch.Size([4, 3, 28, 28])
//a[...] 相当于 a[0] 相当于 a[:,:,:,:]

In [151]: a[0,...].shape
0ut[151]: torch.Size([3, 28, 28])

In [152]: a[:,1,...].shape
0ut[152]: torch.Size([4, 28, 28])

In [155]: a[..., :2].shape
0ut[155]: torch.Size([4, 3, 28, 2])
按掩码选择

masked_select()

In [170]: x = torch.randn(3, 4)
tensor([[-1.3911, -0.7871, -1.6558, -0.2542],
		[-0.9011, 0.5404, -0.6612, 0.3917],
		[-0.3854, 0.2968, 0.6040, 1.5771]])

In [172]: mask = x.ge(0.5)
tensor([[0, 0, 0, 0],
		[0, 1, 0, 0],
		[0, 0, 1, 1]], dtype=torch.uint8)

In [174]: torch.masked_select(x, mask)
0ut[174]: tensor([0.5404, 0.6040, 1.5771])
//masked_select(x, mask)按照mask的掩码选择x中对应索引的元素

In [175]: torch.masked_select(x, mask).shape
0ut[175]. torch.Stze([3])
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/878763.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号