栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 面试经验 > 面试问答

为什么在Pytorch中对网络的权重进行复制时,它将在反向传播后自动更新?

面试问答 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

为什么在Pytorch中对网络的权重进行复制时,它将在反向传播后自动更新?

您必须

clone
使用参数,否则只需复制引用即可。

weights = []for param in model.parameters():    weights.append(param.clone())criterion = nn.BCELoss() # criterion and optimizer setupoptimizer = optim.Adam(model.parameters(), lr=0.001)foo = torch.randn(3, 10) # fake inputtarget = torch.randn(3, 5) # fake targetresult = model(foo) # predictions and comparison and backproploss = criterion(result, target)optimizer.zero_grad()loss.backward()optimizer.step()weights_after_backprop = [] # weights after backpropfor param in model.parameters():    weights_after_backprop.append(param.clone()) # only layer1's weight should update, layer2 is not usedfor i in zip(weights, weights_after_backprop):    print(torch.equal(i[0], i[1]))

这使

FalseFalseTrueTrue


转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/441170.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号