栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

PyTorch中register

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

PyTorch中register

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

PyTorch中register_hook函数学习
  • 一、backward函数
  • 二、register_hook函数


一、backward函数

当输出o不是标量时,不能直接o.backward(),需要向backward传入与输入x具有相同维度的tensor w,o.backward(w) 求的不是 o 对 x 的导数,而是 l = torch.sum(o*w)对 x 的导数,相当于多加了一步按权重线性求和,使得 o 变成了标量。

需要注意:当中间有变量时,如o=f(y),y=g(x),则该w同样作用于求y的梯度上。

import torch

def y_grad(grad):
    print('y的梯度(z对y)为:', grad)

x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y

y.register_hook(y_grad)
z.backward(torch.tensor([1,1,1]))

输出为:

y的梯度(z对y)为: tensor([1., 1., 1.])

y.register_hook(y_grad)
z.backward(torch.tensor([1,2,1]))

输出为:

y的梯度(z对y)为: tensor([1., 2., 1.])

ps:requires_grad=False的变量可以输入进PyTorch的model,且修改变量requires_grad=True

二、register_hook函数

由于反向传播时,不会保留中间变量的梯度,因此该函数的目的主要是对中间变量的梯度进行需要的操作

  1. register_hook(),该函数的参数必须为函数,调用方式为x.register_hook(func),将x的梯度作为参数传入func,func即可对x的梯度进行所需操作
  2. func对中间变量进行操作后,会改变该中间变量的梯度值,将改变的梯度值向后传播,影响叶子变量梯度
  3. 具体计算过程如下
import torch

def y_grad(grad):
    print('y的梯度(z对y)为:', grad)
    return grad**2

x = torch.tensor([1.,2.,3.], requires_grad=True)
y = torch.pow(x, 2)
z = x + y

y.register_hook(y_grad)
z.backward(torch.tensor([1,2,1]))
print(x.grad)

输出为:

y的梯度(z对y)为: tensor([1., 2., 1.])
tensor([ 3., 10., 7.])

计算推导:
z = y + x = x 2 + x z = y + x = x^2 + x z=y+x=x2+x
此时 x x x, y y y, z z z, w w w 都是vector,将 z z z 乘以 w w w 得到 z z z 为标量, z z z对 x x x的导数为:
∂ z ∂ x = ( w ∂ z ∂ y ) ⋅ ∂ y ∂ x + w ⋅ 1 frac{partial z}{partial x} = (wfrac{partial z}{partial y}) cdot frac{partial y}{partial x} + w cdot 1 ∂x∂z​=(w∂y∂z​)⋅∂x∂y​+w⋅1
括号里的是 新的y的梯度,因此函数对y梯度的平方操作要包含w,即
( w ∂ z ∂ y ) 2 (wfrac{partial z}{partial y})^2 (w∂y∂z​)2 因此对于 x [ 1 ] = 2 x[1]=2 x[1]=2,对应的 w = 2 w=2 w=2, ∂ z ∂ y = 1 frac{partial z}{partial y}=1 ∂y∂z​=1, ∂ y ∂ x = 2 x frac{partial y}{partial x}=2x ∂x∂y​=2x,新的 z 对 y z对y z对y的梯度为 ( 2 ⋅ 1 ) = 2 (2cdot1)=2 (2⋅1)=2,经过平方后等于4,传到 x [ 1 ] x[1] x[1]处时 ∂ z ∂ x [ 1 ] = ( w ∂ z ∂ y ) 2 ⋅ ∂ y ∂ x + w ⋅ 1 = ( 2 ⋅ 1 ) 2 ⋅ 2 ⋅ 2 + 2 = 18 frac{partial z}{partial x[1]} = (wfrac{partial z}{partial y})^2 cdot frac{partial y}{partial x} + w cdot 1=(2cdot1)^2cdot2cdot2+2=18 ∂x[1]∂z​=(w∂y∂z​)2⋅∂x∂y​+w⋅1=(2⋅1)2⋅2⋅2+2=18

转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/861536.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号