Use dataset in TensorFlow model training
2021年12月19日shape的要求
data=tf.data.Dataset.from_tensor_slices(([1,2,3],[1,2,3]))
这创建了一个二元组的dataset,即dataset里每个example都是一个二元组
print(list(data))
[[1,1],
[2,2],
[3,3]]
但它不可能是TensorFlow神经网络模型的输入。
TensorFlow神经网络模型要求fit的参数为dataset时,每个example都是一个二元组(或三元组),二元组的第一个元素必须为list,不能为scalar。
data=tf.data.Dataset.from_tensor_slices(([[1],[2],[3]],[1,2,3]))
assert len(dataset.element_spec) == 2, ‘Each example in dataset must be a 2-tuple.’
assert len(dataset.element_spec[0].shape)>0, ‘In each example, the first element itself must be a list.’
batch
Model.fit的参数为dataset时,必须(曾)被调用过Dataset.batch()。可以用以下方法检测
def isBatched(dataset: tf.data.Dataset):
x = dataset
while x is not None:
if x.__class__.__name__ == 'BatchDataset':
return True
else:
x = x._input_dataset
return False
if isBatched(data) == False:
print('data to fit is not batched!')
model.fit(data)