今天在想把一个三维的[3, x, y]的tensor转为[x, y, 3]遇到一些问题,最后的解决方法是把tensor转为numpy,然后使用numpy.transpose(mytensor, [1, 2, 0])解决。因此分析一下torch和numpy中的transpose函数。
参考numpy.transposetorch.transpose numpy.transpose
用法如下:
numpy.transpose(a, axes=None)
如果是一个二维矩阵,那么将返回该矩阵的转置。
a = np.random.randint(0, 10, (3, 2)) print(a, a.shape) a = np.transpose(a) print(a, a.shape) a = np.transpose(a, (1, 0)) print(a, a.shape)
结果为:
[[5 7] [2 3] [9 1]] (3, 2) [[5 2 9] [7 3 1]] (2, 3) [[5 7] [2 3] [9 1]] (3, 2)
可以看到,每次都是将该二维矩阵转置。
如果是多维矩阵:
a = np.random.randint(0, 10, (3, 2, 4)) print(a, a.shape) a = np.transpose(a, (1, 2, 0)) print(a, a.shape)
结果:
[[[4 3 6 8] [7 0 1 1]] [[0 2 6 4] [4 0 6 2]] [[3 3 4 6] [5 6 6 2]]] (3, 2, 4) [[[4 0 3] [3 2 3] [6 6 4] [8 4 6]] [[7 4 5] [0 0 6] [1 6 6] [1 2 2]]] (2, 4, 3)
可以看到,本来矩阵的形状是pre = [3, 2, 4],变换时传入的参数是(1, 2, 0),之后矩阵就变成了[2, 4, 3]也就是[pre[1], pre[2], pre[0]]。
但是具体的变换过程还没搞懂。
对于参数axes,默认是:range(a.ndim)[::-1],也就是原来矩阵shape的逆序。还可以这样用:
a = np.random.randint(0, 10, (3, 2, 4)) print(a, a.shape) b = a.transpose([1, 2, 0]) print(b, b.shape)
结果:
[[[5 0 8 0] [5 4 5 8]] [[6 8 8 8] [8 0 4 5]] [[2 6 8 3] [0 8 3 4]]] (3, 2, 4) [[[5 6 2] [0 8 6] [8 8 8] [0 8 3]] [[5 8 0] [4 0 8] [5 4 3] [8 5 4]]] (2, 4, 3)
效果是一样的。
不会创建一个新的对象a = np.random.randint(0, 10, (2, 4)) print(a) b = a.transpose() print(b)
结果:
[[6 5 4 8] [3 2 1 2]] [[6 3] [5 2] [4 1] [8 2]]
b[0][0] = 15 print(a) print(b)
结果:
[[15 5 4 8] [ 3 2 1 2]] [[15 3] [ 5 2] [ 4 1] [ 8 2]]
一个改变,另一个也改变。
torch.transpose用法如下:
torch.transpose(input, dim0, dim1)
返回一个tensor,是input的转置。并且同样是共享一个实际tensor,一个改变另一个也改变。
import torch a = torch.randint(0, 10, (2, 4)) print(a) b = torch.transpose(a, 1, 0) print(b) c = torch.transpose(a, 0, 1) print(c)
结果:
tensor([[2, 7, 0, 9],
[8, 2, 8, 7]])
tensor([[2, 8],
[7, 2],
[0, 8],
[9, 7]])
tensor([[2, 8],
[7, 2],
[0, 8],
[9, 7]])
可以看到,都是在进行矩阵的转置。
在本函数中,dim0和dim1会互换(转置),所以transpose(a, 1, 0)和transpose(a, 0, 1)效果一致,都是dim[0]和dim[1]互换。这与numpy的函数不同。
a = torch.randint(0, 10, (2, 3, 4)) print(a) b = torch.transpose(a, 1, 2) print(b)
结果:
tensor([[[2, 0, 6, 7],
[8, 8, 0, 2],
[6, 7, 6, 6]],
[[9, 1, 6, 4],
[8, 3, 2, 8],
[0, 0, 4, 9]]])
tensor([[[2, 8, 6],
[0, 8, 7],
[6, 0, 6],
[7, 2, 6]],
[[9, 8, 0],
[1, 3, 0],
[6, 2, 4],
[4, 8, 9]]])
但是,让一个多维矩阵某两维或者某几维的转换公式还不太清楚。


