参考:
pytorch 矩阵维度 - 搜索结果 - 知乎
Pytorch 中的 dim操作介绍 - 大数据 - 亿速云
1.如何理解dim?
- pytorch的dim和numpy的axis很类似
- 不同dim的数据长什么样?
维度为0, 0维张量也叫标量 1 维度为1, 0维张量也叫矢量 [1,2] 维度为2, 0维张量也叫矩阵 [[1,2],[3,4]] 维度为3, 0维张量也叫矩阵数组 [[[1,2],[3,4]],[[1,2],[3,4]]]
二维矩阵a:
a = torch.tensor([[1, 2], [3, 4]])
print(a)
tensor([[1, 2],
[3, 4]])
解释:
三维张量b:
b = torch.tensor([[[3, 2], [1, 4]], [[5, 6], [7, 8]]])
print(b)
tensor([[[3, 2],
[1, 4]],
[[5, 6],
[7, 8]]])
解释:
2.在不同dim的计算
核心:在不同dim上的计算就是对这个dim中的元素的计算,以sum为例,计算b在不同维度的sum。
- dim=0
s = torch.sum(b, dim=0)
print(s)
tensor([[ 8, 8],
[ 8, 12]])
解释:
- dim=1
s = torch.sum(b, dim=1)
print(s)
tensor([[ 4, 6],
[12, 14]])
解释:
- dim=2
s = torch.sum(b, dim=2)
print(s)
tensor([[ 5, 5],
[11, 15]])
在 b 的第 2 维求和,就是对标量 3 和 2, 1 和 4, 5 和 6 , 7 和 8 求和
note:在进行计算时,结果的维度发生了变换,如果不想改变,需要keepdim=True



