栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

torch中的 inplace operation操作错误记录

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

torch中的 inplace operation操作错误记录

import torch
w = torch.rand(4, requires_grad=True)
w += 1
loss = w.sum()
loss.backward()

执行 loss 对参数 w 进行求导,会出现报错:RuntimeError: a leaf Variable that requires grad is being used in an in-place operation。是第 3 行代码 w += 1,如果把这句改成 w = w + 1,再执行就不会报错了。

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i]) * w[i]
loss = x.sum()
loss.backward()
可使用 with torch.autograd.set_detect_anomaly(True) 定位具体的出错位置。

with torch.autograd.set_detect_anomaly(True):
    x = torch.zeros(4)
    w = torch.rand(4, requires_grad=True)
    x[0] = torch.rand(1) * w[0]
    for i in range(3):
        x[i+1] = torch.sin(x[i]) * w[i]
    loss = x.sum()
    loss.backward()

Error detected in SinBackward. 大概是 torch.sin() 函数出现了问题。将第 6 行代码 x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i]

import torch
x = torch.zeros(4)
w = torch.rand(4, requires_grad=True)
x[0] = torch.rand(1) * w[0]
for i in range(3):
    x[i+1] = torch.sin(x[i].clone()) * w[i]
loss = x.sum()
loss.backward()

inplace operation 的报错:

x += 1 改成 x = x + 1;
x[:, :, 0:3] = x[:, :, 0:3] + 1 改成 x[:, :, 0:3] = x[:, :, 0:3].clone() + 1;
x[i+1] = torch.sin(x[i]) * w[i] 改成 x[i+1] = torch.sin(x[i].clone()) * w[i];
可使用 with torch.autograd.set_detect_anomaly(True) 帮助定位出错位置,一般会运行较长时间。

x = x + 1 is not in-place, because it takes the objects pointed to by x, creates a new Variable, adds 1 to x putting the result in the new Variable, and overwrites the object referenced by x to point to the new var. There are no in-place modifications, you only change Python references (you can check that id(x) is different before and after that line).

On the other hand, doing x += 1 or x[0] = 1 will modify the data of the Variable in-place, so that no copy is done. However some functions (in your case *) require the inputs to never change after they compute the output, or they wouldn’t be able to compute the gradient. That’s why an error is raised.


参考:https://blog.csdn.net/weixin_39679367/article/details/122754199

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/884487.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号