栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch基础操作 —— 10. torch.dsplit 按深度分割张量

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch基础操作 —— 10. torch.dsplit 按深度分割张量

文章目录
  • torch.dsplit
  • 例程

torch.dsplit

dsplit 是另外一个和 chunk 相似的操作,但是和 chunk 有一定的不同。

   torch.dsplit(input, indices_or_sections) → List of Tensors

这里,最关键的就是 indices_or_sections 的描述内容,根据官网的介绍:

  • 根据 index_or_sections 将输入(具有三个或更多维度的张量)拆分为多个深度方向的张量。
  • 这等价于调用 torch.tensor_split(input,indices_or_sections,dim=2)(分割维度为 1),不同的是,如果indices_or_sections 是一个整数,则必须将分割维度平均分割,否则会抛出运行时错误。
  • 另外,这个函数是比较少见的没有dim这个参数,所以我们来看看它具体是怎么表现的。
例程

我们先来创建一个3维的张量。

>>> tensor = torch.arange(64).reshape(4, 4, 4)
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23],
         [24, 25, 26, 27],
         [28, 29, 30, 31]],

        [[32, 33, 34, 35],
         [36, 37, 38, 39],
         [40, 41, 42, 43],
         [44, 45, 46, 47]],

        [[48, 49, 50, 51],
         [52, 53, 54, 55],
         [56, 57, 58, 59],
         [60, 61, 62, 63]]])

然后

>>> torch.dsplit(tensor, 2)
(tensor([[[ 0,  1],
         [ 4,  5],
         [ 8,  9],
         [12, 13]],

        [[16, 17],
         [20, 21],
         [24, 25],
         [28, 29]],

        [[32, 33],
         [36, 37],
         [40, 41],
         [44, 45]],

        [[48, 49],
         [52, 53],
         [56, 57],
         [60, 61]]]),
tensor([[[ 2,  3],
         [ 6,  7],
         [10, 11],
         [14, 15]],

        [[18, 19],
         [22, 23],
         [26, 27],
         [30, 31]],

        [[34, 35],
         [38, 39],
         [42, 43],
         [46, 47]],

        [[50, 51],
         [54, 55],
         [58, 59],
         [62, 63]]]))
>>> tensors = torch.dsplit(tensor, 2)
>>> tensors[0].shape
torch.Size([4, 4, 2])

除此以外,由于它等价于 numpy.dsplit 函数,所以也支持适用元组、数组、一维张量来切片。我们看看对此是怎么定义的:

  • 如果 indices_or_sections 是整数n,或值为 n 的张量,则input分为 n 个部分。
  • 如果indices_or_sections 是整数的列表或元组,或一维长张量,则输入在列表、元组或张量中的每个元素位置处进行拆分。 例如,indices_or_sections=[2, 3] 将导致张量 input[:2]、input[2:3] 和 input[3:]。

所以,我们可以这样做:

>>> tensors = torch.dsplit(tensor, [2, 3]) # tensor[:2], tensor[2:3], tensor[3:]
>>> tensors[0].shape
torch.Size([4, 4, 2])

>>> tensors[1].shape
torch.Size([4, 4, 1])

>>> tensors[2].shape
torch.Size([4, 4, 1])

现在你明白怎么使用了吗?

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/339277.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号