栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

nn.Linear

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

nn.Linear

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)
这个函数主要是进行空间的线性映射

  • in_features:输入数据的数据维度
  • out_features:输出数据的数据维度
函数执行过程:

假设我们有一批数据 x x x, x x x的维度为20维,这一批数据一共有128个,我们要将20维的 x x x映射到30维空间的 y y y中,下面是计算过程,其中 w w w是Linear函数的weight权重

y = x W T + b y = xW^{T}+b y=xWT+b

其中 x = ( x 11 x 12 . . . x 1 , 20 x 21 x 22 . . . x 2 , 20 . . . . . . . . . . . . x 128 , 1 x 128 , 2 . . . x 128 , 20 ) 128 × 20 x=begin{pmatrix} x_{11} & x_{12} & ... & x_{1,20} \ x_{21} & x_{22} & ... & x_{2,20} \ ... & ... & ... & ... \ x_{128,1} & x_{128,2} & ... & x_{128,20} \ end{pmatrix}_{128times 20} x=⎝⎜⎜⎛​x11​x21​...x128,1​​x12​x22​...x128,2​​............​x1,20​x2,20​...x128,20​​⎠⎟⎟⎞​128×20​ w = ( w 11 w 12 . . . w 1 , 20 w 21 w 22 . . . w 2 , 20 . . . . . . . . . . . . w 30 , 1 w 30 , 2 . . . w 30 , 20 ) 30 × 20 w = begin{pmatrix} w_{11} & w_{12} & ... & w_{1,20} \ w_{21} & w_{22} & ... & w_{2,20} \ ... & ... & ... & ... \ w_{30,1} & w_{30,2} & ... & w_{30,20} \ end{pmatrix}_{30times 20} w=⎝⎜⎜⎛​w11​w21​...w30,1​​w12​w22​...w30,2​​............​w1,20​w2,20​...w30,20​​⎠⎟⎟⎞​30×20​

( x 11 x 12 . . . x 1 , 20 x 21 x 22 . . . x 2 , 20 . . . . . . . . . . . . x 128 , 1 x 128 , 2 . . . x 128 , 20 ) 128 × 20 ( w 11 w 21 . . . w 30 , 1 w 12 w 22 . . . w 30 , 2 . . . . . . . . . . . . w 1 , 20 w 2 , 20 . . . w 30 , 20 ) 20 × 30 = ( y 11 y 12 . . . y 1 , 30 y 12 y 22 . . . y 2 , 30 . . . . . . . . . . . . y 128 , 1 y 128 , 2 . . . y 128 , 30 ) 128 × 30 begin{pmatrix} x_{11} & x_{12} & ... & x_{1,20} \ x_{21} & x_{22} & ... & x_{2,20} \ ... & ... & ... & ... \ x_{128,1} & x_{128,2} & ... & x_{128,20} \ end{pmatrix}_{128times 20} begin{pmatrix} w_{11} & w_{21} & ... & w_{30,1} \ w_{12} & w_{22} & ... & w_{30,2} \ ... & ... & ... & ... \ w_{1,20} & w_{2,20} & ... & w_{30,20} \ end{pmatrix}_{20times 30} = begin{pmatrix} y_{11} & y_{12} & ... & y_{1,30} \ y_{12} & y_{22} & ... & y_{2,30} \ ... & ... & ... & ... \ y_{128,1} & y_{128,2} & ... & y_{128,30} \ end{pmatrix}_{128times 30} ⎝⎜⎜⎛​x11​x21​...x128,1​​x12​x22​...x128,2​​............​x1,20​x2,20​...x128,20​​⎠⎟⎟⎞​128×20​⎝⎜⎜⎛​w11​w12​...w1,20​​w21​w22​...w2,20​​............​w30,1​w30,2​...w30,20​​⎠⎟⎟⎞​20×30​=⎝⎜⎜⎛​y11​y12​...y128,1​​y12​y22​...y128,2​​............​y1,30​y2,30​...y128,30​​⎠⎟⎟⎞​128×30​

一个简单的例子
import torch


x = torch.randn(128, 20)  # 输入的维度是(128,20)
linear = torch.nn.Linear(20, 30)  # 20, 30是指维度
output = linear(x)

print('linear.weight.shape:   ', linear.weight.shape)
print('linear.bias.shape:     ', linear.bias.shape)
print('output.shape:          ', output.shape)

# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
# .t就是w转置之后的部分
ans = torch.mm(x, linear.weight.t()) + linear.bias
print('ans.shape:             ', ans.shape)
print(torch.equal(ans, output))


'''output:
linear.weight.shape:    torch.Size([30, 20])
linear.bias.shape:      torch.Size([30])
output.shape:           torch.Size([128, 30])
ans.shape:              torch.Size([128, 30])
True
'''
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/303615.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号