添加
@tf.function确实可以显着提高速度。看看这个:
import tensorflow as tfdata = tf.random.normal((1000, 10, 10, 1))dataset = tf.data.Dataset.from_tensors(data).batch(10)def iterate_1(dataset): for x in dataset: x = x@tf.functiondef iterate_2(dataset): for x in dataset: x = x%timeit -n 1000 iterate_1(dataset) # 1.46 ms ± 8.2 µs per loop%timeit -n 1000 iterate_2(dataset) # 239 µs ± 10.2 µs per loop
如您所见,使用进行迭代的
@tf.function速度提高了6倍以上。



