在情况下
tf.where,您有一个具有三个输入的函数,即condition
C,true
T值和false值
F以及一个输出
Out。渐变接收一个值,并且必须返回三个值。当前,没有为该条件计算梯度(这几乎没有道理),因此您只需要为
T和进行梯度
F。假设输入和输出是向量,可以想象
C[0]是
True。然后
Out[0]来自
T[0],其梯度应传播回去。另一方面,
F[0]将被丢弃,因此其梯度应设为零。如果
Out[1]是
False,则的梯度
F[1]应该传播,而不是
T[1]。简而言之,
T你应该传播给定的梯度这里
C是
True并使其零它在哪里
False,而相对来说
F。如果看一下(operation)的梯度的实现
tf.where``Select,它确实做到了:
@ops.RegisterGradient("Select")def _SelectGrad(op, grad): c = op.inputs[0] x = op.inputs[1] zeros = array_ops.zeros_like(x) return (None, array_ops.where(c, grad, zeros), array_ops.where( c, zeros, grad))请注意,输入值本身未在计算中使用,这将通过产生这些输入的操作的梯度来完成。对于
tf.cond,代码有点复杂,因为
Merge在不同的上下文中使用了相同的操作(),并且
tf.cond还使用了
Switch内部的操作。但是,想法是相同的。本质上,
Switch操作用于每个输入,因此被激活的输入(第一个(如果条件是条件
True,第二个否则)获得接收到的梯度,另一个输入得到“关闭”梯度(如
None),并且不会传播再往回走。



