tensorflow dataset.shuffle()之后需要带上.cache()
2021年3月8日要注意,shuffle也是延迟执行的。若有代码
import tensorflow as tf
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.shuffle(3) #.cache()
print(list(dataset))
print(list(dataset))
两次打印的顺序可能不同。如下
[<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=2>] [<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=1>]
修复方法是取消注释.cache()。