Use dataset in TensorFlow model training

2021年12月19日

shape的要求

data=tf.data.Dataset.from_tensor_slices(([1,2,3],[1,2,3]))

这创建了一个二元组的dataset,即dataset里每个example都是一个二元组
print(list(data))
[[1,1],
[2,2],
[3,3]]

但它不可能是TensorFlow神经网络模型的输入。

TensorFlow神经网络模型要求fit的参数为dataset时,每个example都是一个二元组(或三元组),二元组的第一个元素必须为list,不能为scalar。

data=tf.data.Dataset.from_tensor_slices(([[1],[2],[3]],[1,2,3]))

assert len(dataset.element_spec) == 2, ‘Each example in dataset must be a 2-tuple.’
assert len(dataset.element_spec[0].shape)>0, ‘In each example, the first element itself must be a list.’

batch

Model.fit的参数为dataset时,必须(曾)被调用过Dataset.batch()。可以用以下方法检测

def isBatched(dataset: tf.data.Dataset):
	x = dataset
	while x is not None:
		if x.__class__.__name__ == 'BatchDataset':
			return True
		else:
			x = x._input_dataset
	return False


if isBatched(data) == False:
	print('data to fit is not batched!')
model.fit(data)