- torch.column_stack
- 例程
- 2D
- 3D
- 4D
除了 cat 这个函数以外,torch也提供了其他的张量粘合方法。column_stack 是以列方向黏贴张量的方法,我们来看一看函数的定义。
torch.column_stack(tensors, *, out=None) → Tensor例程
官方的例程如下
>>> a = torch.tensor([1, 2, 3])
>>> b = torch.tensor([4, 5, 6])
>>> torch.column_stack((a, b))
tensor([[1, 4],
[2, 5],
[3, 6]])
>>> a = torch.arange(5)
>>> b = torch.arange(10).reshape(5, 2)
>>> torch.column_stack((a, b, b))
tensor([[0, 0, 1, 0, 1],
[1, 2, 3, 2, 3],
[2, 4, 5, 4, 5],
[3, 6, 7, 6, 7],
[4, 8, 9, 8, 9]])
不过我们要考虑一个复杂点的情况,就是对于高维的张量,它又表现为什么形式。
2D>>> a = torch.arange(0, 8).reshape(-1, 2) >>> b = torch.arange(0, 8).reshape(-1, 2) >>> c = torch.column_stack((a, b, a, b)) >>> a.shape torch.Size([4, 2]) >>> b.shape torch.Size([4, 2]) >>> c.shape torch.Size([4, 8])3D
>>> a = torch.arange(0, 8).reshape(-1, 2, 2) >>> b = torch.arange(0, 8).reshape(-1, 2, 2) >>> c = torch.column_stack((a, b, a, b)) >>> a.shape torch.Size([2, 2, 2]) >>> b.shape torch.Size([2, 2, 2]) >>> c.shape torch.Size([2, 8, 2])4D
>>> a = torch.arange(0, 8).reshape(1, 2, 2, 2) >>> b = torch.arange(0, 8).reshape(1, 2, 2, 2) >>> c = torch.column_stack((a, b, a, b)) >>> a.shape torch.Size([1, 2, 2, 2]) >>> b.shape torch.Size([1, 2, 2, 2]) >>> c.shape torch.Size([1, 8, 2, 2])
从上面这些例子可以看到,这个函数的其实和 torch.cat(*, dim=1) 很像,事实上是如果执行如下命令,也会得到TRUE的输出。
a = torch.arange(0, 8).reshape(2, 2, 2)
b = torch.arange(10, 18).reshape(2, 2, 2)
c1 = torch.cat((a, b, a, b), dim=1)
c2 = torch.column_stack((a, b, a, b))
if torch.equal(c1, c2):
print("TRUE")
else:
print("FALSE")



