tensorflow dataset.shuffle()之后需要带上.cache()

要注意,shuffle也是延迟执行的。若有代码

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
dataset = dataset.shuffle(3) #.cache()
print(list(dataset))
print(list(dataset))

两次打印的顺序可能不同。如下

[<tf.Tensor: shape=(), dtype=int32, numpy=1>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=2>]
[<tf.Tensor: shape=(), dtype=int32, numpy=2>, <tf.Tensor: shape=(), dtype=int32, numpy=3>, <tf.Tensor: shape=(), dtype=int32, numpy=1>]

修复方法是取消注释.cache()。

发表评论

电子邮件地址不会被公开。

:wink: :twisted: :roll: :oops: :mrgreen: :lol: :idea: :evil: :cry: :arrow: :?: :-| :-x :-o :-P :-D :-? :) :( :!: 8-O 8)