栏目分类:
子分类:
返回
名师互学网用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
名师互学网 > IT > 软件开发 > 后端开发 > Python

【Tensorflow2.x学习笔记】tf.GradientTape自动求梯度

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

【Tensorflow2.x学习笔记】tf.GradientTape自动求梯度

tf.constant()函数

作用: 创建tensor常量
函数形式: tf.constant(value, shape, dtype=None, name=None)
参数释义: value:值,shape:数据形状,dtype:数据类型,name:名称

tf.GradientTape()函数

作用: 用于计算函数梯度,配合with as结构使用
函数形式: 使用__init__函数初始化对象,__enter__函数和__exit__函数
配合使用实现上下文管理器(用于连接需要计算梯度的函数与变量)

with as结构

作用: with as可用于简化try finally代码,与下述形式的try finally等价

try:  
    执行 __enter__的内容  
    执行 with_block.  
finally:  
    执行 __exit__内容 

执行过程: with expression as variable的执行过程是首先执行__enter__函数,它的返回值会赋值给as后面的variable,然后执行with-block中的语句,不论发生什么with-block执行后都执行__exit__函数

tape.watch()函数

作用: 确保某个tensor被tape追踪
函数形式: watch(tensor)
参数释义: tensor: 一个Tensor或者一个Tensor列表

注意:watch函数把需要计算梯度的变量加入。GradientTape默认只监控由 tf.Variable 创建的traiable=True属性(默认)的变量。若变量是constant,则计算梯度 需要增加 tape.watch([a, b, c])函数。当然,也可以设置不自动监控可训练变量,完全由自己指定,设置watch_accessed_variables=False就行了(一般用不到)。

tape.gradient()函数

作用: 根据tape上面的上下文来计算某个或者某些tensor的梯度
函数形式: gradient(target,sources,output_gradients=None,unconnected_gradients=tf.UnconnectedGradients.NONE)
参数释义: target:需要求导的目标函数方程、sources:被求导的一个Tensor或者Tensor列表

代码实践
# -*- coding : utf-8 -*-            
# @Time : 2022/3/4 23:02
# @Author : SXQ
# @FileName : autograd

import tensorflow as tf

# constant函数用于生成tensor常量
x = tf.constant(2.)
a = tf.constant(2.)
b = tf.constant(3.)
c = tf.constant(4.)

# with可用于简化try finally代码
# with expression as variable的执行过程是,首先执行__enter__函数
# 它的返回值会赋值给as后面的variable
# 然后执行with-block中的语句,不论发生什么with-block执行后都执行__exit__函数
with tf.GradientTape() as tape:
    # 确保某个tensor被tape追踪
    tape.watch([a, b, c])
    # 函数公式
    y = a ** 2 * x + b * x + c
# gradient函数根据tape上面的上下文来计算某个或者某些tensor的梯度
[dy_da, dy_db, dy_dc] = tape.gradient(y, [a, b, c])
print(dy_da, dy_db, dy_dc)
转载请注明:文章转载自 www.mshxw.com
本文地址:https://www.mshxw.com/it/754749.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 MSHXW.COM

ICP备案号:晋ICP备2021003244-6号