In [130]:a=torch.rand(4,3,28,28) In [131]: a[0].shape 0ut[131]: torch.Size([3, 28, 28]) In [138]: a[0,0].shape 0ut[138]: torch.Size([28, 28]) In [139]: a[0,0,2,4] 0ut[139]: tensor(0. 8082)前/后N项
In [140]: a.shape 0ut[140]: torch.Size([4, 3, 28, 28]) In [141]: a[:2].shape//第0个维度从0到1 0ut[141]: torch.Size([2, 3, 28, 28]) In [142]: a[:2,:1,:,:].shape 0ut[142]: torch.Size([2,1, 28, 28]) In [143]: a[:2,1:,:,:].shape Out[143]: torch.Size([2, 2,28, 28]) In [144]: a[:2,-1:,:,:].shape//从最后一个元素到末尾 0ut[144]: torch.Size([2, 1, 28, 28])有间隔地选取
In [145]: a[:,:,0:28:2,0:28:2].shape//从0到28间隔2选择 Out[145]: torch.Size([4, 3, 14, 14]) In [146]: a[:,:,::2,::2].shape//从头到尾间隔step步长选择 Out[146]: torch.Size([4, 3,14, 14]) //通用形式:start:end:step按特定索引选取
index_select()
In [149]: a.shape 0ut[149]: torch.Size([4, 3, 28, 28]) In [159]: a.index_select(0, torch.sensor([0, 2])).shape 0ut[159]: torch.Size([2, 3, 28, 28]) In [159]: a.index_select(1, torch.sensor([1, 2])).shape 0ut[159]: torch.Size([4, 2, 28, 28]) In [168]: a.index_select(2, torch.arange(8)).shape 0ut[168]: torch.Size([4, 3, 8, 28]) //.index_select第二个参数不能以list的形式直接输入,需要以tensor的形式输入
…
In [149]: a.shape Out[149]: torch.Size([4, 3, 28, 28]) In [150]: a[...].shape 0ut[150]: torch.Size([4, 3, 28, 28]) //a[...] 相当于 a[0] 相当于 a[:,:,:,:] In [151]: a[0,...].shape 0ut[151]: torch.Size([3, 28, 28]) In [152]: a[:,1,...].shape 0ut[152]: torch.Size([4, 28, 28]) In [155]: a[..., :2].shape 0ut[155]: torch.Size([4, 3, 28, 2])按掩码选择
masked_select()
In [170]: x = torch.randn(3, 4) tensor([[-1.3911, -0.7871, -1.6558, -0.2542], [-0.9011, 0.5404, -0.6612, 0.3917], [-0.3854, 0.2968, 0.6040, 1.5771]]) In [172]: mask = x.ge(0.5) tensor([[0, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]], dtype=torch.uint8) In [174]: torch.masked_select(x, mask) 0ut[174]: tensor([0.5404, 0.6040, 1.5771]) //masked_select(x, mask)按照mask的掩码选择x中对应索引的元素 In [175]: torch.masked_select(x, mask).shape 0ut[175]. torch.Stze([3])



