栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch基础操作 —— 7. torch.cat 拼接操作

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch基础操作 —— 7. torch.cat 拼接操作

文章目录
  • torch.cat
  • 例程
    • 低维度时的拼接
    • 高维度时的拼接
  • 验证

torch.cat

张量拼接是非常常见的操作,以OpenCV为例,有时候我们需要把彩色图片(通常为3通道数据)分别进行处理,然后再重新组合在一起,生成新的图片。对于类似的框架来说,也提供了类似的函数。

现在,让我们来看看Torch的张量拼接函数的原型:

    torch.cat(tensors, dim=0, *, out=None) -> Tensor
  • tensors,通常是一组张量,要求大小维度相同,否则会导致拼接失败。
  • dim,是拼接的方向,默认是0.
例程 低维度时的拼接

这个函数本质上并不难理解,但是唯一比较麻烦的就是dim,也就是轴方向,这是来自Numpy的概念,我觉得对于这个概念最好的理解,还是直接看源码最好。

>>> import torch
>>> ones = torch.ones(4, 5)
>>> zeros = torch.zeros(4, 5)

>>> torch.cat((ones, zeros), dim=0)
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])
        
>>> torch.cat((ones, zeros), dim=1)
tensor([[1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
        [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.]])
        
>>> torch.cat((ones, zeros), dim=2)
Traceback (most recent call last):
IndexError: Dimension out of range (expected to be in range of [-2, 1], but got 2)        

可以看到,对于低维度的拼接时,torch.cat 仅支持把两个张量按列、行方向进行拼接。那么对于高维度的拼接,比如三维度时,又是怎样的呢?

高维度时的拼接
>>> ones = torch.ones(2, 3, 4, 5)
>>> zeros = torch.zeros(2, 3, 4, 5)

>>> cat1 = torch.cat((ones, zeros), dim=0)
>>> cat1.shape
torch.Size([4, 3, 4, 5])

>>> cat2 = torch.cat((ones, zeros), dim=1)
>>> cat2.shape
torch.Size([2, 6, 4, 5])

>>> cat3 = torch.cat((ones, zeros), dim=2)
>>> cat3.shape
torch.Size([2, 3, 8, 5])

>>> cat4 = torch.cat((ones, zeros), dim=3)
>>> cat4.shape
torch.Size([2, 3, 4, 10])

从这个例子可以很明显的看出,cat操作的dim,是依据张量围度从左往右进行计算的。所以很多网上所说的dim=0是沿着X轴、dim=1是沿着Y轴,dim=2是沿着Z轴这种说法是十分不准确的。

所以更准确的说法应该是:dim,指定操作沿着张量的第N位执行指令。 对于上面这个例子来说,执行cat操作时,dim=0,即指定以张量维度第0位,执行拼接操作。

验证

为了证明结果,这里执行一个小程序片段,我们分别打印 cat1 和 cat4,看看同样的执行顺序会分别输出什么内容

dim0, dim1, dim2, dim3 = cat1.shape
for i in range(dim0):
    print("i:", i, end=" ")
    for j in range(dim1):
        print("j:", j)
        for k in range(dim2):
            for l in range(dim3):
                print(cat1[i, j, k, l].item(), end=" ")
            print("")
        print("")

print("-----------------------------")

dim0, dim1, dim2, dim3 = cat4.shape
for i in range(dim0):
    print("i:", i, end=" ")
    for j in range(dim1):
        print("j:", j)
        for k in range(dim2):
            for l in range(dim3):
                print(cat4[i, j, k, l].item(), end=" ")
            print("")
        print("")

i: 0 j: 0
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 1
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 2
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

i: 1 j: 0
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 1
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

j: 2
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 
1.0 1.0 1.0 1.0 1.0 

i: 2 j: 0
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 1
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 2
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

i: 3 j: 0
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 1
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

j: 2
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 
0.0 0.0 0.0 0.0 0.0 

-----------------------------
i: 0 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

i: 1 j: 0
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 1
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 

j: 2
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
1.0 1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/331495.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号