栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

自动梯度autograd中的with torch.no

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

自动梯度autograd中的with torch.no

PyTorch: Tensors and autograd(auto-gradient)
因为pytorch中可以使用autograd实现神经网络反向传播过程的自动计算。 当我们使用autograd的时候,前向传播会定义一个计算图,图中的节点都是张量,图的边是函数,用于从输入张量产生输出张量。通过这个图的反向传播就可以轻松获得gradient。

虽然听起来很复杂,但是用起来是很简单的,每个张量都代表计算图中的一个节点。如果x是一个张量,并且对其设置x.requires_grad=True,那么x.grad会存储x相对于某个标量梯度的张量。

1. .requires_grad=True

默认情况下Tensor的requires_grad属性都是False,这样就不会保留其gradient。当其设置为x.requires_grad=True时,x上的操作就会被追踪(track)。并且反向传播时候会将gradient保存在.grad属性中。

import torch

# 默认情况下
x = torch.tensor([1.0, 2.0, 3.0])
print(x)		# 输出:tensor([1., 2., 3.])
x += 1
print(x)		# 输出:tensor([1., 2., 3.])

# 设置requires_grad=True
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(y)		# 输出:tensor([1., 2., 3.], requires_grad=True)
y = y + 1.0
print(y)		# 输出:tensor([2., 3., 4.], grad_fn=)
print(y.grad)

注意:

  • 对于叶子节点的张量不能执行+=之类的操作,会报错a leaf Variable that requires grad is being used in an in-place operation.
  • 在这里最后一行代码会有个警告UserWarning大致意思错误访问了非叶节点的.grid,我故意在这打印一下grad的,不用管它。如果你实在是想看一下,那可以在y = y + 1.0之后加上一句y.retain_grad(),之后就不会报错了,但是也只显示一个None。
2. Tensor的 .grad_fn属性

上边提到当你设置.requires_grad=True之后,该张量上的操作就会追踪。
可以看到上边代码y的输出和x是不一样的,第一次打印的时候后边显示requires_grad=True,第二次打印的时候后边有个grad_fn=,设置追踪之后,grad_fn属性会保存加在上边的操作。这里上边执行了x = x+1之后,就显示是进行了一个加的操作。
但是只有非叶节点才有该属性,叶子节点会显示None。

3. with torch.no_grad()

如果你不想某个操作被追踪,那你就需要使用with torch.no_grad():。
被该语句包裹起来的代码就不会被追踪gradient。

import torch

# 设置requires_grad=True
y = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print(y)			# 1 tensor([1., 2., 3.], requires_grad=True)
y = y + 1.0
print(y)			# 2 tensor([2., 3., 4.], grad_fn=)
y = y*1.2
print(y)			# 3 tensor([2.4000, 3.6000, 4.8000], grad_fn=)
print(y.grad_fn)	# 4 

with torch.no_grad():
    y = y - 2
print(y)			# 5 tensor([0.4000, 1.6000, 2.8000])
print(y.grad_fn)	# 6 None
  • 第三个输出中grad_fn显示是追踪了一个乘法操作
  • 第四个输出单独打印一下grad_fn,再次显示是追踪了乘法操作
  • 第五个输出是在设定了with torch.no_grad():之后,发现输出中没有grad_fn属性了
  • 第六个输出单独打印一下grad_fn确实是不存在,即with torch.no_grad():包裹之下,那个乘法操作没有被追踪
4. .grad

上边举例的xy都是非叶节点,是不会存储.grad的,现在放到一个神经网络中试一下下。
我们用一个三节多项式 y = a + b x + c x 2 + d x 3 y = a + bx + cx^2 + dx^3 y=a+bx+cx2+dx3来拟合 sin ⁡ ( x ) sin(x) sin(x)只用一个简单的两层的神经网路。

代码如下:

import torch
import math

dtype = torch.float
device = torch.device("cpu")
# 取消下边这行的注释就可以在GPU上运行
# device = torch.device("cuda:0")

# Create Tensors to hold input and outputs.
# 默认情况下requires_grad=False, 表示我们不需要计算这个张量在反向传播过程中的梯度。
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
y = torch.sin(x)

# 随机初始化参数,设置requires_grad=True表示我们希望将反向传播过程中的梯度保留
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Now loss is a Tensor of shape (1,)
    # loss.item() 获取loss中的标量值
    loss = (y_pred - y).pow(2).sum()

    # 使用autogrid计算,这个调用会计算所有的requires_grad=True的张量的gradient。
    # 然后他们的值会分别存储在对应的张量中a.grad, b.grad. c.grad d.grad
    loss.backward()

    if t == 500:
        print(a.grad)				# tensor(-905.9598)
        print(a.grad_fn)			# None
        print(a.is_leaf)			# True
        print(y_pred.grad)			# None
        print(y_pred.grad_fn)		# 
        print(y_pred.is_leaf)		# False

    # 手动更新权重
    # Wrap in torch.no_grad()因为之前设置了requires_grad=True,但是我们不希望在autograd记录下a-操作的gradient
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad

        # 更新权重之后手动清楚存储梯度gradient的张量
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None

上边代码我挑了循环中的一节打印了一下。is_leaf可以检验是否是叶子节点。
a是叶子节点,也可以看到存储的a.grad,但是grad_fn是None。
y不是叶子节点,打印y.grad会有警告,但是能看到grad_fn。
就是上边说的叶子节点通常为None,在这里只有结果节点的grad_fn才有效,用于指示梯度函数是哪种类型。


今天钻牛角尖貌似又成功了。(^-^)V

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/296164.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号