栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch的合并与分割

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch的合并与分割

一、合并 1. cat 函数
    规则:
      所合并的数据的dim一致要合并的维度上shape可以不一致,其余的shape必须一致理解:[class,student]合并=>[class,student],合并后的班级在含以上相同
    例子
a = torch.rand(4,3,16,32)
b = torch.rand(4,3,16,32)
# 第一参数,list,第二参数,再哪个维度合并
print(torch.cat([a,b],dim=2).shape)
# out: torch.Size([4, 3, 32, 32])
2.stack函数
    规则:
      要合并的两个维度必须一致会在合并的维度前插入一个新的维度理解:[class,student]合并=>[class_id,class,student],相当于合并后两个班级分开,意义上不同,比如dim=0维度上的班级是尖子班,dim=1维度上的班级是普通班。
    例子
print(' stack ')
print(torch.stack([a,b],dim=2).shape) # torch.Size([4, 3, 2, 16, 32]) ,在dim2插入一个新的维度

a = torch.rand(3,5)
b = torch.rand(3,5)
print(torch.stack([a,b],dim=0).shape)
# out : torch.Size([2, 3, 5])
二、分割 1. split函数
    规则:
      给定在某一维度拆分后长度给定在某一维度拆分后的每个长度
    例子
a = torch.rand(32,8)
b = torch.rand(32,8)
c = torch.stack([a,b],dim=0)
aa,bb = c.split([1,1],dim=0) # 给定要拆分的dim=0的shape
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([1, 32, 8])  bb.shape= torch.Size([1, 32, 8])
aa,bb = c.split(1,dim=0) #在0维度差分成长度为1
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([1, 32, 8])  bb.shape= torch.Size([1, 32, 8])

c = torch.rand(9,4)
aa,bb,cc = c.split(3,dim=0) # 拆分成3个size[3,4]
print(aa.shape,bb.shape,cc.shape)
# torch.Size([3, 4]) torch.Size([3, 4]) torch.Size([3, 4])
aa,bb,cc = c.split([2,3,4],dim=0)
print(aa.shape,bb.shape,cc.shape)
# torch.Size([2, 4]) torch.Size([3, 4]) torch.Size([4, 4])
2. chuck函数
    规则
      参数:要分割的数量(不足的向上取整),和维度
    例子
a = torch.rand(4,5,2)
aa,bb = a.chunk(2,dim=0)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([2, 5, 2])  bb.shape= torch.Size([2, 5, 2])
aa,bb = a.chunk(2,dim=1)
print('aa.shape=',aa.shape,' bb.shape=',bb.shape)
# aa.shape= torch.Size([4, 3, 2])  bb.shape= torch.Size([4, 2, 2])
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/725047.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号