最近在学习分类网络,想去尝试用center_loss去提升分类效果。观看了这篇博客在Keras使用center-losss损失函数Keras自定义损失函数,感觉思路挺清晰,代码段也不复杂,就决定试一试。但当我将center_loss插入我的模型后出现了这样的问题:代码训练正常,但在读取已经训练好的模型进行分类测试时会报错。
ValueError: Variable centers already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope? Originally defined at:
File "D:anacandaenvsTFlibsite-packagestensorflow_corepythonframeworkops.py", line 1756, in __init__
self._traceback = tf_stack.extract_stack()
File "D:anacandaenvsTFlibsite-packagestensorflow_corepythonframeworkops.py", line 3322, in _create_op_internal
op_def=op_def)
File "D:anacandaenvsTFlibsite-packagestensorflow_corepythonframeworkop_def_library.py", line 742, in _apply_op_helper
attrs=attr_protos, op_def=op_def)
File "D:anacandaenvsTFlibsite-packagestensorflow_corepythonopsgen_state_ops.py", line 1527, in variable_v2
shared_name=shared_name, name=name)
File "D:anacandaenvsTFlibsite-packagestensorflow_corepythonopsstate_ops.py", line 79, in variable_op_v2
shared_name=shared_name)
尝试解决
然后我便开始了解决这一问题的艰辛经历。
首先,我看到网上的很多答案都是让在代码前加上这句话:
tf.reset_default_graph()
但是我加上之后会出现另一个错误
ValueError: Tensor("ArgMax:0", shape=(?,), dtype=int64) must be from the same graph as Tensor("metrics/categorical_accuracy/ArgMax:0", shape=(?,), dtype=int64).
需要删除之前添加的那句话才能解决。
因为代码中构建变量用的是 tf.get_variable ,可以按报错信息将参数reuse=True设置好。但是可惜的是我用的是TF2.1的版本,涉及到版本问题需要调用TF1.x的包,并提示我reuse这个参数不存在,也不行重新安装版本随即放弃。
后来在这篇博客TensorFlow : name_scope和variable_scope 区别分析里,找到了解决方法。将报错中的变量centers共享即可。
更改前的代码:
# 设置trainable=False是因为样本中心不是由梯度进行更新的
centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,initializer=tf.constant_initializer(0), trainable=False)
# 将label展开为一维的,如果labels已经是一维的,则该动作其实无必要
labels = tf.reshape(labels, [-1])
更改后的代码:
# 设置trainable=False是因为样本中心不是由梯度进行更新的
with tf.variable_scope('varScope', reuse=tf.AUTO_REUSE):
centers = tf.get_variable('centers', [num_classes, len_features], dtype=tf.float32,initializer=tf.constant_initializer(0), trainable=False)
# 将label展开为一维的,如果labels已经是一维的,则该动作其实无必要
labels = tf.reshape(labels, [-1])
然后就可以运行出结果啦,用了center_loss之后对于我的数据集而言分类的准确率会提高一点。
其实就是增加tf.variable_scope语句,然后将reuse参数设置为reuse=tf.AUTO_REUSE。
希望我的小小分享会对你有帮助



