tensorflow dataset.shuffle()之后需要带上.cache()

2021年3月8日

要注意,shuffle也是延迟执行的。若有代码

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.shuffle(3) #.cache()
print(list(dataset))
print(list(dataset))

两次打印的顺序可能不同。如下

[<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=2>]
[<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=1>]

修复方法是取消注释.cache()。