- torch.dsplit
- 例程
dsplit 是另外一个和 chunk 相似的操作,但是和 chunk 有一定的不同。
torch.dsplit(input, indices_or_sections) → List of Tensors
这里,最关键的就是 indices_or_sections 的描述内容,根据官网的介绍:
- 根据 index_or_sections 将输入(具有三个或更多维度的张量)拆分为多个深度方向的张量。
- 这等价于调用 torch.tensor_split(input,indices_or_sections,dim=2)(分割维度为 1),不同的是,如果indices_or_sections 是一个整数,则必须将分割维度平均分割,否则会抛出运行时错误。
- 另外,这个函数是比较少见的没有dim这个参数,所以我们来看看它具体是怎么表现的。
我们先来创建一个3维的张量。
>>> tensor = torch.arange(64).reshape(4, 4, 4)
tensor([[[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11],
[12, 13, 14, 15]],
[[16, 17, 18, 19],
[20, 21, 22, 23],
[24, 25, 26, 27],
[28, 29, 30, 31]],
[[32, 33, 34, 35],
[36, 37, 38, 39],
[40, 41, 42, 43],
[44, 45, 46, 47]],
[[48, 49, 50, 51],
[52, 53, 54, 55],
[56, 57, 58, 59],
[60, 61, 62, 63]]])
然后
>>> torch.dsplit(tensor, 2)
(tensor([[[ 0, 1],
[ 4, 5],
[ 8, 9],
[12, 13]],
[[16, 17],
[20, 21],
[24, 25],
[28, 29]],
[[32, 33],
[36, 37],
[40, 41],
[44, 45]],
[[48, 49],
[52, 53],
[56, 57],
[60, 61]]]),
tensor([[[ 2, 3],
[ 6, 7],
[10, 11],
[14, 15]],
[[18, 19],
[22, 23],
[26, 27],
[30, 31]],
[[34, 35],
[38, 39],
[42, 43],
[46, 47]],
[[50, 51],
[54, 55],
[58, 59],
[62, 63]]]))
>>> tensors = torch.dsplit(tensor, 2)
>>> tensors[0].shape
torch.Size([4, 4, 2])
除此以外,由于它等价于 numpy.dsplit 函数,所以也支持适用元组、数组、一维张量来切片。我们看看对此是怎么定义的:
- 如果 indices_or_sections 是整数n,或值为 n 的张量,则input分为 n 个部分。
- 如果indices_or_sections 是整数的列表或元组,或一维长张量,则输入在列表、元组或张量中的每个元素位置处进行拆分。 例如,indices_or_sections=[2, 3] 将导致张量 input[:2]、input[2:3] 和 input[3:]。
所以,我们可以这样做:
>>> tensors = torch.dsplit(tensor, [2, 3]) # tensor[:2], tensor[2:3], tensor[3:] >>> tensors[0].shape torch.Size([4, 4, 2]) >>> tensors[1].shape torch.Size([4, 4, 1]) >>> tensors[2].shape torch.Size([4, 4, 1])
现在你明白怎么使用了吗?



