PyTorch中register_hook函数学习提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档
- 一、backward函数
- 二、register_hook函数
一、backward函数
当输出o不是标量时,不能直接o.backward(),需要向backward传入与输入x具有相同维度的tensor w,o.backward(w) 求的不是 o 对 x 的导数,而是 l = torch.sum(o*w)对 x 的导数,相当于多加了一步按权重线性求和,使得 o 变成了标量。
需要注意:当中间有变量时,如o=f(y),y=g(x),则该w同样作用于求y的梯度上。
import torch
def y_grad(grad):
print('y的梯度(z对y)为:', grad)
x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y
y.register_hook(y_grad)
z.backward(torch.tensor([1,1,1]))
输出为:
y的梯度(z对y)为: tensor([1., 1., 1.])
y.register_hook(y_grad) z.backward(torch.tensor([1,2,1]))
输出为:
y的梯度(z对y)为: tensor([1., 2., 1.])
ps:requires_grad=False的变量可以输入进PyTorch的model,且修改变量requires_grad=True
二、register_hook函数由于反向传播时,不会保留中间变量的梯度,因此该函数的目的主要是对中间变量的梯度进行需要的操作
- register_hook(),该函数的参数必须为函数,调用方式为x.register_hook(func),将x的梯度作为参数传入func,func即可对x的梯度进行所需操作
- func对中间变量进行操作后,会改变该中间变量的梯度值,将改变的梯度值向后传播,影响叶子变量梯度
- 具体计算过程如下
import torch
def y_grad(grad):
print('y的梯度(z对y)为:', grad)
return grad**2
x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y
y.register_hook(y_grad)
z.backward(torch.tensor([1,2,1]))
print(x.grad)
输出为:
y的梯度(z对y)为: tensor([1., 2., 1.])
tensor([ 3., 10., 7.])
计算推导:
z
=
y
+
x
=
x
2
+
x
z = y + x = x^2 + x
z=y+x=x2+x
此时
x
x
x,
y
y
y,
z
z
z,
w
w
w 都是vector,将
z
z
z 乘以
w
w
w 得到
z
z
z 为标量,
z
z
z对
x
x
x的导数为:
∂
z
∂
x
=
(
w
∂
z
∂
y
)
⋅
∂
y
∂
x
+
w
⋅
1
frac{partial z}{partial x} = (wfrac{partial z}{partial y}) cdot frac{partial y}{partial x} + w cdot 1
∂x∂z=(w∂y∂z)⋅∂x∂y+w⋅1
括号里的是 新的y的梯度,因此函数对y梯度的平方操作要包含w,即
(
w
∂
z
∂
y
)
2
(wfrac{partial z}{partial y})^2
(w∂y∂z)2 因此对于
x
[
1
]
=
2
x[1]=2
x[1]=2,对应的
w
=
2
w=2
w=2,
∂
z
∂
y
=
1
frac{partial z}{partial y}=1
∂y∂z=1,
∂
y
∂
x
=
2
x
frac{partial y}{partial x}=2x
∂x∂y=2x,新的
z
对
y
z对y
z对y的梯度为
(
2
⋅
1
)
=
2
(2cdot1)=2
(2⋅1)=2,经过平方后等于4,传到
x
[
1
]
x[1]
x[1]处时
∂
z
∂
x
[
1
]
=
(
w
∂
z
∂
y
)
2
⋅
∂
y
∂
x
+
w
⋅
1
=
(
2
⋅
1
)
2
⋅
2
⋅
2
+
2
=
18
frac{partial z}{partial x[1]} = (wfrac{partial z}{partial y})^2 cdot frac{partial y}{partial x} + w cdot 1=(2cdot1)^2cdot2cdot2+2=18
∂x[1]∂z=(w∂y∂z)2⋅∂x∂y+w⋅1=(2⋅1)2⋅2⋅2+2=18



