TensorFlow
2.0中没有内置机制可以覆盖作用域内的内置运算符的所有渐变。但是,如果您能够为每次呼叫内置操作员修改呼叫站点,则可以使用
tf.custom_gradient装饰器,如下所示:
@tf.custom_gradientdef custom_square(x): def grad(dy): return tf.constant(0.0) return tf.square(x), gradwith tf.Graph().as_default() as g: x = tf.Variable(5.0) with tf.GradientTape() as tape: s_2 = custom_square(x) with tf.compat.v1.Session() as sess: sess.run(tf.compat.v1.global_variables_initializer()) print(sess.run(tape.gradient(s_2, x)))



