栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

Pytorch实验6:论文代码中的高级语法

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

Pytorch实验6:论文代码中的高级语法

整理了一下读GSnet代码时遇到的不懂的语法。

系列文章

  • Pytorch实验一:从零实现Logistic回归和Softmax回归
  • Pytorch实验2:手动实现前馈神经网络、Dropout、正则化、K折交叉验证,解决多分类、二分类、回归任务
  • Pytorch实验3:动手实现卷积神经网络、空洞卷积、残差网络
  • Pytorch实验4:手动/torch实现三种循环神经网络解决交通流量预测任务
  • Pytorch实验5:疫情微博文本情感分类 (简化版SMP2020赛题)
  • Pytorch实验6:论文代码中的高级语法

Pytorch中的高级语法整理
  • torch.tensor.repeat()
  • bmm (input, mat2)
  • tensor.permute(input, dims)
    • 关于permute(排列)函数的理解
    • permute / transpose + contiguous + view
  • torch.squeeze(input, dim=None)
  • torch.unsqueeze(input, dim=None)

torch.tensor.repeat()

repeat里面的参数 × times ×tensor原本的维度,从而得到repeat之后的维度(由内而外)。

if __name__ == '__main__':

    import torch
    from torch.autograd import Variable

    x = torch.tensor([0,4,8])
    y = torch.tensor([0,4,8,12])
    xx = x.repeat(4) # 行方向上复制4次
    print(xx)
    print(y.view(-1,1))
    yy = y.view(-1,1).repeat(1,3) # 行方向上复制3次,列方向上复制一次
    print(yy)

    yyy = y.view(-1, 1).repeat(3,1, 3) # 行方向复制3次,列方向不变,最外层复制3次
    print(yyy.shape)
    print(yyy)

bmm (input, mat2)

batch matrix multiplication

if __name__ == '__main__':
    import torch
    input = torch.randn(10, 3, 4)
    mat2 = torch.randn(10, 4, 5)
    res = torch.bmm(input, mat2)
    print(res.size())

结果:

tensor.permute(input, dims)

Returns a view of the original tensor inputwith its dimensionspermuted.

if __name__ == '__main__':
    import torch
    x = torch.tensor([[[1,2,3],[4,5,6]],[[1,2,3],[4,5,6]]])
    print(x)
    print(x.size())
    y = torch.permute(x, (2, 0, 1))
    print(y)
    print(y.size())

实验结果:

关于permute(排列)函数的理解
  • 比如图片img的size比如是(28,28,3)就可以利用img.permute(2,0,1)得到一个size为(3,28,28)的tensor。

  • 调用tensor.permute(1,0)意为将1轴(列轴)与0轴(行轴)调换,相当于进行转置

  • 可以理解为,对于一个高维的Tensor执行permute,我们没有改变数据的相对位置,而只是旋转了一下这个(超)立方体。或者也可以说,改变了我们对这个(超)立方体的“观察角度”而已。

  • 个人理解:怎么切西瓜?横着切还是竖着切。

  • 连续使用transpose也可实现permute的效果(连续选择任意两维进行转置,等价于多个维度的交换)

permute / transpose + contiguous + view

view只能作用在contiguous的variable上,如果在view之前调用了transpose、permute等,就需要调用contiguous()来返回一个contiguous copy;

判断ternsor是否为contiguous,可以调用torch.Tensor.is_contiguous()函数:

import torch 
x = torch.ones(10, 10) 
x.is_contiguous()                                 # True 
x.transpose(0, 1).is_contiguous()                 # False
x.transpose(0, 1).contiguous().is_contiguous()    # True

reshape(),相当于 tensor.contiguous().view(),这样就省去了对tensor做view()变换前,调用contiguous()的麻烦;换句话说,调用permute / transpose之后,可以用contiguous().view(),也可以用reshape().

from: https://zhuanlan.zhihu.com/p/76583143

torch.squeeze(input, dim=None)

Returns a tensor with all the dimensions of input of size 1 removed.

If dim is given, then the indexed dimension will be removed if it equals one.

Note that if the tensor has a batch dimension of size 1, then squeeze(input) will also remove the batch dimension, which can lead to unexpected errors.

import torch
x = torch.zeros(2, 1, 2, 1, 2)
y = torch.squeeze(x)
print(y.size())
y = torch.squeeze(x, 0)
print(y.size())
y = torch.squeeze(x, 1)
print(y.size())
torch.unsqueeze(input, dim=None)

Returns a new tensor with a dimension of size one inserted at the specified position.

A dimvalue within the range [-input.dim() - 1, input.dim() + 1) can be used. Negative dim will correspond to unsqueeze() applied at dim = dim + input.dim() + 1.

    import torch
    x = torch.tensor([1, 2, 3, 4])
    torch.unsqueeze(x, 0)
    print(x)
    print(x.shape)
    print(x.unsqueeze(0))
    print(x.unsqueeze(0).shape)
    print(x.unsqueeze(1))
    print(x.unsqueeze(1).shape)

(1,4)的区别就是[x1,x2,...,xn]外面多了一个[],变成了[[X]]

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/461598.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号