这样一个tensor 想用类似
x[np.where(x > 8)] = 8
x[np.where(x<3)] = 3
的形式将其改成
这种批量的条件判断改值,由于tensor不能直接用索引修改值
尝试了几种方法,
比如更改为Varient、numpy 等
会出现
TypeError: only integer scalar arrays can be converted to a scalar index
或者
TypeError: ‘ResourceVariable’ object does not support item assignment
等类似的错误
import tensorflow as tf tensor_input = tf.constant([i for i in range(20)], tf.float32) tensor_input = tf.reshape(tensor_input, [4, 5]) print(tensor_input) tensor_input = tf.where(tf.greater(tensor_input,8),8,tensor_input) tensor_input = tf.where(tf.less(tensor_input,3),3,tensor_input) print(tensor_input)
运行如下
凌晨找到的解决方案,原理就查api吧 权当抛砖引玉



