栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

pytorch的梯度传递

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

pytorch的梯度传递

pytorch的梯度传递
  • 1.requires_grad的传递
    • 1.1三种情况下的梯度传递
    • 1.2利用requires_grad=False冻结骨干网络
    • 1.3网络中的数据是记录梯度的

1.requires_grad的传递

requires_gard 是tensor的一个属性,requires_gard=False表示不记录梯度,requires_gard=True表示记录张量的梯度。

每次的计算抽象为张量 A 与 B 做数学运算得到张量 C,C 是否记录梯度取决于 A 和 B的情况。

1.1三种情况下的梯度传递
  • A.requires_gard=False B.requires_gard=True ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=True)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    ---------------------------------------------------------------------------------
    True
    
  • A.requires_gard=True B.requires_gard=False ⇒ C.requires_gard=True
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=True)
    C = A + B
    C.requires_grad
    ----------------------------------------------------------------------------------
    True
    
  • A.requires_gard=False B.requires_gard=False ⇒ C.requires_gard=False
    A = torch.tensor([1., 2., 3.], requires_grad=False)
    B = torch.tensor([4., 5., 6.], requires_grad=False)
    C = A + B
    C.requires_grad
    -----------------------------------------------------------------------------------
    False
    

由此可见,只有当输入都不需要记录梯度时,后续计算的张量才不记录梯度,只要有一个输入张量计算梯度,后续的张量均需要记录梯度

1.2利用requires_grad=False冻结骨干网络
# 获得pytorch的预训练模型
model = torchvision.models.resnet18(pretrained=True)
# 冻结model的梯度计算
for p in model.parameters():
    p.requires_grad = False
# 替换最上层的fc
model.fc = torch.nn.Linear(512, 100)
# 新创建的liner层默认requires_grad=True
optmizer = torch.optim.SGD(model.fc.parameters(), lr=0.001)
1.3网络中的数据是记录梯度的
model = torchvision.models.resnet18(pretrained=True)
inputs = torch.randn(1, 3, 128, 128)
inputs.requires_grad
model(inputs).requires_grad
--------------------------------------------------------------------
False
True

虽然输入网络的tensor inputs 是不记录梯度的(requires_grad=False),但是网络的参数记录梯度,导致中间层的输出数据和最终的输出数据的requires_grad=True。

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/875545.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号