- torch.cat
- 例程
- 低维度时的拼接
- 高维度时的拼接
- 验证
张量拼接是非常常见的操作,以OpenCV为例,有时候我们需要把彩色图片(通常为3通道数据)分别进行处理,然后再重新组合在一起,生成新的图片。对于类似的框架来说,也提供了类似的函数。
现在,让我们来看看Torch的张量拼接函数的原型:
torch.cat(tensors, dim=0, *, out=None) -> Tensor
- tensors,通常是一组张量,要求大小维度相同,否则会导致拼接失败。
- dim,是拼接的方向,默认是0.
这个函数本质上并不难理解,但是唯一比较麻烦的就是dim,也就是轴方向,这是来自Numpy的概念,我觉得对于这个概念最好的理解,还是直接看源码最好。
>>> import torch
>>> ones = torch.ones(4, 5)
>>> zeros = torch.zeros(4, 5)
>>> torch.cat((ones, zeros), dim=0)
tensor([[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.],
[0., 0., 0., 0., 0.]])
>>> torch.cat((ones, zeros), dim=1)
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])
>>> torch.cat((ones, zeros), dim=2)
Traceback (most recent call last):
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)
可以看到,对于低维度的拼接时,torch.cat 仅支持把两个张量按列、行方向进行拼接。那么对于高维度的拼接,比如三维度时,又是怎样的呢?
高维度时的拼接>>> ones = torch.ones(2, 3, 4, 5) >>> zeros = torch.zeros(2, 3, 4, 5) >>> cat1 = torch.cat((ones, zeros), dim=0) >>> cat1.shape torch.Size([4, 3, 4, 5]) >>> cat2 = torch.cat((ones, zeros), dim=1) >>> cat2.shape torch.Size([2, 6, 4, 5]) >>> cat3 = torch.cat((ones, zeros), dim=2) >>> cat3.shape torch.Size([2, 3, 8, 5]) >>> cat4 = torch.cat((ones, zeros), dim=3) >>> cat4.shape torch.Size([2, 3, 4, 10])
从这个例子可以很明显的看出,cat操作的dim,是依据张量围度从左往右进行计算的。所以很多网上所说的dim=0是沿着X轴、dim=1是沿着Y轴,dim=2是沿着Z轴这种说法是十分不准确的。
所以更准确的说法应该是:dim,指定操作沿着张量的第N位执行指令。 对于上面这个例子来说,执行cat操作时,dim=0,即指定以张量维度第0位,执行拼接操作。
验证为了证明结果,这里执行一个小程序片段,我们分别打印 cat1 和 cat4,看看同样的执行顺序会分别输出什么内容
dim0, dim1, dim2, dim3 = cat1.shape
for i in range(dim0):
print("i:", i, end=" ")
for j in range(dim1):
print("j:", j)
for k in range(dim2):
for l in range(dim3):
print(cat1[i, j, k, l].item(), end=" ")
print("")
print("")
print("-----------------------------")
dim0, dim1, dim2, dim3 = cat4.shape
for i in range(dim0):
print("i:", i, end=" ")
for j in range(dim1):
print("j:", j)
for k in range(dim2):
for l in range(dim3):
print(cat4[i, j, k, l].item(), end=" ")
print("")
print("")
i: 0 j: 0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
j: 1
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
j: 2
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
i: 1 j: 0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
j: 1
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
j: 2
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
1.0 1.0 1.0 1.0 1.0
i: 2 j: 0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
j: 1
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
j: 2
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
i: 3 j: 0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
j: 1
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
j: 2
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
0.0 0.0 0.0 0.0 0.0
-----------------------------
i: 0 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
i: 1 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0



