tl; dr 为避免这种情况,请将您的输入投射到
float32
X = tf.cast(iris[:, :3], tf.float32) y = tf.cast(iris[:, 3], tf.float32)
或搭配
numpy:
X = np.array(iris[:, :3], dtype=np.float32)y = np.array(iris[:, 3], dtype=np.float32)
说明
默认情况下,Tensorflow使用
floatx,默认为
float32,这是深度学习的标准。您可以验证以下内容:
import tensorflow as tftf.keras.backend.floatx()Out[3]: 'float32'
您提供的输入(虹膜数据集)的输入类型为dtype
float64,因此Tensorflow的默认权重dtype与输入之间不匹配。Tensorflow不喜欢这样,因为强制转换(更改dtype)的成本很高。操作不同dtype的张量时(例如,比较
float32logit和
float64标签),Tensorflow通常会引发错误。
它所谈论的“新行为”:
图层my_model_1正在将dtype float64的输入张量转换为该图层的float32的dtype,这是TensorFlow 2中的新行为
是它将自动将输入dtype强制转换为
float32。在这种情况下,Tensorflow 1.X可能引发了异常,尽管我不能说我曾经使用过它。



