Solving 121. Best Time to Buy and Sell Stock with an LSTM Neural Network

February 13, 2022

Original problem: https://leetcode.com/problems/best-time-to-buy-and-sell-stock/
Related: https://www.geeksforgeeks.org/maximum-difference-between-two-elements/


import random

import numpy as np
import tensorflow as tf

print(tf.__version__)

# A string comparison like tf.__version__ > '2.6' is fragile, and `keras`
# would be undefined on older versions; import it unconditionally instead.
from tensorflow import keras


def maxDiff(arr):
	# One pass: track the running minimum and the best (current - minimum)
	# difference seen so far. Assumes len(arr) >= 2.
	max_diff = arr[1] - arr[0]
	min_element = arr[0]

	for i in range(1, len(arr)):
		if arr[i] - min_element > max_diff:
			max_diff = arr[i] - min_element

		if arr[i] < min_element:
			min_element = arr[i]
	return max_diff


def generateTargetArray(targetMaxDiff, length, maxNum):
	# Build a random array of the given length, with elements in [0, maxNum],
	# whose maxDiff is exactly targetMaxDiff.
	assert targetMaxDiff <= maxNum
	arr = [0] * length

	plannedMaxIndex = random.randint(1, length - 1)

	min_element = maxNum
	minIndex = 0
	for i in range(0, len(arr)):
		if i == plannedMaxIndex:
			if min_element + targetMaxDiff <= maxNum:
				arr[i] = min_element + targetMaxDiff
			else:
				d = min_element + targetMaxDiff - maxNum
				t = np.array(arr)
				t[minIndex:i] -= d
				arr = t.tolist()
				min_element -= d
				arr[i] = min_element + targetMaxDiff

			assert arr[i] <= maxNum
		else:
			arr[i] = random.randint(0, min(min_element + targetMaxDiff, maxNum))

		if arr[i] < min_element:
			min_element = arr[i]
			minIndex = i

	assert maxDiff(arr) == targetMaxDiff, arr
	return arr


class ProfitModel(tf.keras.Model):
	def __init__(self, *args, **kwargs):
		# With return_sequences=True the LSTM returns its output for every
		# time step; with return_sequences=False it returns only the output
		# of the last time step. return_state=True additionally returns the
		# final hidden state and final cell (carry) state.
		super().__init__(*args, **kwargs)
		self.lstm = tf.keras.layers.LSTM(64, return_sequences=True, return_state=True)
		self.dense = tf.keras.layers.Dense(1, activation='sigmoid')

	def call(self, inputs, training=None, mask=None):
		whole_seq_output, final_memory_state, final_carry_state = self.lstm(inputs)

		return self.dense(final_carry_state)


class RoundingMetric(tf.keras.metrics.Metric):
	def __init__(self, maxNum, name='RoundingEqual'):
		super(RoundingMetric, self).__init__(name=name)
		self.maxNum = maxNum

		self.matches = self.add_weight('matches', dtype=tf.int32)
		self.total = self.add_weight('total', dtype=tf.int32)

	def update_state(self, y_true, y_pred, sample_weight=None):
		y_true = tf.math.rint(y_true * self.maxNum)
		y_pred = tf.math.rint(y_pred * self.maxNum)

		acc = tf.math.equal(y_true, y_pred)
		self.matches.assign_add(tf.cast(tf.reduce_sum(tf.cast(acc, tf.float32)), tf.int32))
		# len() on a symbolic tensor fails in graph mode; use tf.shape instead.
		self.total.assign_add(tf.shape(y_true)[0])

	def result(self):
		return self.matches / self.total


if __name__ == '__main__':
	X = []
	Y = []
	maxNum = 20
	for n in range(2000):
		for d in range(maxNum + 1):
			arr = generateTargetArray(d, 10, maxNum)
			X.append(arr)
			Y.append(d / maxNum)

	X = tf.constant(X, dtype=tf.float32) / maxNum
	X = tf.expand_dims(X, -1)
	Y = tf.constant(Y)

	size = X.shape[0]
	validationSize = int(size * 0.2)
	validationX = X[0:validationSize, ]
	validationY = Y[0:validationSize, ]

	trainingX = X[validationSize:, ]
	trainingY = Y[validationSize:, ]

	print(f'Training size is {trainingX.shape[0]}. Validation size is {validationX.shape[0]}.')

	model = ProfitModel()
	model.compile(optimizer='adam', loss='mse', metrics=[RoundingMetric(maxNum)])

	checkpoint_filepath = 'saved-models/lstm'
	checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
		filepath=checkpoint_filepath,
		save_weights_only=True,
		monitor='val_RoundingEqual', mode='max',
		save_best_only=True, verbose=0)

	# Fit on the training split only; fitting on the full X would leak the
	# validation samples into training.
	model.fit(trainingX, trainingY, epochs=80, batch_size=128, validation_data=(validationX, validationY),
			  callbacks=[checkpoint_callback])

	model.load_weights(checkpoint_filepath)
	print('Best weights loaded')

	testX = []
	testY = []
	for d in range(maxNum + 1):
		arr = generateTargetArray(d, 10, maxNum)
		testX.append(arr)
		testY.append(d)

	X = tf.constant(testX, dtype=tf.float32) / maxNum
	X = tf.expand_dims(X, -1)
	predicts = model.predict(X) * maxNum
	for i in range(len(X)):
		if testY[i] == round(predicts[i][0]):
			print(f'{testX[i]}: actual is {testY[i]}, predicted is {predicts[i][0]}. PASS')
		else:
			print(f'{testX[i]}: actual is {testY[i]}, predicted is {predicts[i][0]}. FAIL')

The baseline algorithm is maxDiff. I first use generateTargetArray to generate arrays with a given targetMaxDiff. For example, generateTargetArray(5, 10, 20) produces arrays such as

[18, 20, 2, 7, 0, 1, 3, 0, 2, 4]
[20, 16, 20, 12, 11, 15, 16, 6, 3, 8]
[15, 20, 19, 10, 12, 5, 6, 4, 6, 6]
[18, 19, 12, 8, 3, 8, 5, 6, 5, 6]

and so on; the maxDiff of each is 5. These arrays form the dataset, which is then split into a training set and a validation set.
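As a sanity check, running maxDiff (restated here in a slightly more compact but equivalent form) over the sample arrays above confirms they all have a maximum difference of 5:

```python
def maxDiff(arr):
    # Track the running minimum and the best (current - minimum) difference.
    max_diff = arr[1] - arr[0]
    min_element = arr[0]
    for x in arr[1:]:
        max_diff = max(max_diff, x - min_element)
        min_element = min(min_element, x)
    return max_diff

samples = [
    [18, 20, 2, 7, 0, 1, 3, 0, 2, 4],
    [20, 16, 20, 12, 11, 15, 16, 6, 3, 8],
    [15, 20, 19, 10, 12, 5, 6, 4, 6, 6],
    [18, 19, 12, 8, 3, 8, 5, 6, 5, 6],
]
print([maxDiff(s) for s in samples])  # → [5, 5, 5, 5]
```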

The TensorFlow model is ProfitModel, written in the model-subclassing style, with an LSTM (a kind of RNN) as its inner layer. The LSTM's final cell state (C state) is taken as the output and fed through a dense layer whose sigmoid activation rescales it into the [0, 1] range. After training, the model is tested on a freshly generated dataset.
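The RoundingEqual metric used during training can be sketched in plain NumPy: labels and predictions both live in [0, 1], so scaling by maxNum and rounding compares them at integer resolution. (rounding_accuracy is a hypothetical helper name for this sketch, not part of the original code.)

```python
import numpy as np

def rounding_accuracy(y_true, y_pred, maxNum):
    # Scale the [0, 1] values back to the integer range and round,
    # then count how many predictions land on the correct integer.
    t = np.rint(np.asarray(y_true) * maxNum)
    p = np.rint(np.asarray(y_pred) * maxNum)
    return float(np.mean(t == p))

# With maxNum = 20: a prediction of 0.26 rounds to 5 and matches a
# label of 0.25 (i.e. 5/20), while 0.44 rounds to 9 and misses 0.5 (10/20).
print(rounding_accuracy([0.25, 0.5], [0.26, 0.44], 20))  # → 0.5
```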

After 80 epochs the accuracy reaches about 0.75. The training curves show no sign of overfitting, so increasing the number of epochs or the number of neurons should be feasible.

2.6.0
Training size is 33600. Validation size is 8400.
2022-02-13 16:52:11.525719: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/80
329/329 [==============================] - 7s 11ms/step - loss: 0.0271 - RoundingEqual: 0.1402 - val_loss: 0.0213 - val_RoundingEqual: 0.1571
Epoch 21/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0029 - RoundingEqual: 0.3775 - val_loss: 0.0027 - val_RoundingEqual: 0.3776
Epoch 22/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0027 - RoundingEqual: 0.3911 - val_loss: 0.0026 - val_RoundingEqual: 0.3858
Epoch 23/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0025 - RoundingEqual: 0.4052 - val_loss: 0.0023 - val_RoundingEqual: 0.4180
Epoch 24/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0023 - RoundingEqual: 0.4164 - val_loss: 0.0022 - val_RoundingEqual: 0.3898
Epoch 25/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0022 - RoundingEqual: 0.4304 - val_loss: 0.0022 - val_RoundingEqual: 0.4170
Epoch 26/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0021 - RoundingEqual: 0.4395 - val_loss: 0.0020 - val_RoundingEqual: 0.4442
Epoch 27/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0019 - RoundingEqual: 0.4577 - val_loss: 0.0021 - val_RoundingEqual: 0.4068
Epoch 28/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0017 - RoundingEqual: 0.4707 - val_loss: 0.0018 - val_RoundingEqual: 0.4687
Epoch 29/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0016 - RoundingEqual: 0.4865 - val_loss: 0.0016 - val_RoundingEqual: 0.4956
Epoch 30/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0016 - RoundingEqual: 0.4960 - val_loss: 0.0015 - val_RoundingEqual: 0.5224
Epoch 31/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0015 - RoundingEqual: 0.4996 - val_loss: 0.0017 - val_RoundingEqual: 0.4715
Epoch 32/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0015 - RoundingEqual: 0.5159 - val_loss: 0.0014 - val_RoundingEqual: 0.5315
Epoch 33/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0014 - RoundingEqual: 0.5185 - val_loss: 0.0016 - val_RoundingEqual: 0.4849
Epoch 34/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5312 - val_loss: 0.0013 - val_RoundingEqual: 0.5235
Epoch 35/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5378 - val_loss: 0.0012 - val_RoundingEqual: 0.5468
Epoch 36/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5440 - val_loss: 0.0012 - val_RoundingEqual: 0.5538
Epoch 37/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0012 - RoundingEqual: 0.5502 - val_loss: 0.0012 - val_RoundingEqual: 0.5526
Epoch 38/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0012 - RoundingEqual: 0.5576 - val_loss: 0.0012 - val_RoundingEqual: 0.5731
Epoch 39/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5651 - val_loss: 0.0012 - val_RoundingEqual: 0.5452
Epoch 40/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5710 - val_loss: 0.0012 - val_RoundingEqual: 0.5705
Epoch 41/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5741 - val_loss: 0.0010 - val_RoundingEqual: 0.5973
Epoch 42/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5807 - val_loss: 9.9008e-04 - val_RoundingEqual: 0.6115
Epoch 43/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0010 - RoundingEqual: 0.5907 - val_loss: 9.5682e-04 - val_RoundingEqual: 0.6182
Epoch 44/80
329/329 [==============================] - 3s 8ms/step - loss: 9.9314e-04 - RoundingEqual: 0.5973 - val_loss: 0.0011 - val_RoundingEqual: 0.5685
Epoch 45/80
329/329 [==============================] - 2s 8ms/step - loss: 9.9009e-04 - RoundingEqual: 0.5951 - val_loss: 0.0010 - val_RoundingEqual: 0.6017
Epoch 46/80
329/329 [==============================] - 3s 8ms/step - loss: 9.4755e-04 - RoundingEqual: 0.6051 - val_loss: 9.4081e-04 - val_RoundingEqual: 0.6105
Epoch 47/80
329/329 [==============================] - 3s 8ms/step - loss: 9.3701e-04 - RoundingEqual: 0.6098 - val_loss: 0.0010 - val_RoundingEqual: 0.5693
Epoch 48/80
329/329 [==============================] - 2s 8ms/step - loss: 9.0903e-04 - RoundingEqual: 0.6179 - val_loss: 9.2725e-04 - val_RoundingEqual: 0.6096
Epoch 49/80
329/329 [==============================] - 2s 8ms/step - loss: 8.8730e-04 - RoundingEqual: 0.6190 - val_loss: 8.4559e-04 - val_RoundingEqual: 0.6237
Epoch 50/80
329/329 [==============================] - 3s 8ms/step - loss: 8.7649e-04 - RoundingEqual: 0.6241 - val_loss: 8.8519e-04 - val_RoundingEqual: 0.6263
Epoch 51/80
329/329 [==============================] - 3s 8ms/step - loss: 8.4369e-04 - RoundingEqual: 0.6302 - val_loss: 8.1327e-04 - val_RoundingEqual: 0.6539
Epoch 52/80
329/329 [==============================] - 3s 8ms/step - loss: 8.3751e-04 - RoundingEqual: 0.6342 - val_loss: 8.1143e-04 - val_RoundingEqual: 0.6482
Epoch 53/80
329/329 [==============================] - 2s 8ms/step - loss: 8.0424e-04 - RoundingEqual: 0.6409 - val_loss: 7.7465e-04 - val_RoundingEqual: 0.6392
Epoch 54/80
329/329 [==============================] - 3s 8ms/step - loss: 7.8093e-04 - RoundingEqual: 0.6496 - val_loss: 7.9678e-04 - val_RoundingEqual: 0.6342
Epoch 55/80
329/329 [==============================] - 2s 8ms/step - loss: 7.7724e-04 - RoundingEqual: 0.6499 - val_loss: 7.7000e-04 - val_RoundingEqual: 0.6479
Epoch 56/80
329/329 [==============================] - 3s 8ms/step - loss: 7.4981e-04 - RoundingEqual: 0.6593 - val_loss: 9.3061e-04 - val_RoundingEqual: 0.6290
Epoch 57/80
329/329 [==============================] - 3s 8ms/step - loss: 7.5051e-04 - RoundingEqual: 0.6585 - val_loss: 8.4115e-04 - val_RoundingEqual: 0.6277
Epoch 58/80
329/329 [==============================] - 3s 8ms/step - loss: 7.3048e-04 - RoundingEqual: 0.6629 - val_loss: 7.7131e-04 - val_RoundingEqual: 0.6632
Epoch 59/80
329/329 [==============================] - 3s 8ms/step - loss: 7.1310e-04 - RoundingEqual: 0.6714 - val_loss: 7.7137e-04 - val_RoundingEqual: 0.6651
Epoch 60/80
329/329 [==============================] - 3s 8ms/step - loss: 6.9212e-04 - RoundingEqual: 0.6758 - val_loss: 7.0888e-04 - val_RoundingEqual: 0.6527
Epoch 61/80
329/329 [==============================] - 3s 8ms/step - loss: 6.8705e-04 - RoundingEqual: 0.6767 - val_loss: 6.7692e-04 - val_RoundingEqual: 0.6780
Epoch 62/80
329/329 [==============================] - 3s 8ms/step - loss: 6.6169e-04 - RoundingEqual: 0.6855 - val_loss: 7.4443e-04 - val_RoundingEqual: 0.6480
Epoch 63/80
329/329 [==============================] - 3s 8ms/step - loss: 6.5932e-04 - RoundingEqual: 0.6884 - val_loss: 6.1253e-04 - val_RoundingEqual: 0.7054
Epoch 64/80
329/329 [==============================] - 3s 8ms/step - loss: 6.3962e-04 - RoundingEqual: 0.6973 - val_loss: 6.6404e-04 - val_RoundingEqual: 0.6770
Epoch 65/80
329/329 [==============================] - 3s 8ms/step - loss: 6.2708e-04 - RoundingEqual: 0.7022 - val_loss: 6.5355e-04 - val_RoundingEqual: 0.7019
Epoch 66/80
329/329 [==============================] - 3s 8ms/step - loss: 6.1974e-04 - RoundingEqual: 0.7034 - val_loss: 6.1600e-04 - val_RoundingEqual: 0.6987
Epoch 67/80
329/329 [==============================] - 3s 8ms/step - loss: 6.0954e-04 - RoundingEqual: 0.7050 - val_loss: 5.8688e-04 - val_RoundingEqual: 0.7173
Epoch 68/80
329/329 [==============================] - 3s 8ms/step - loss: 5.9655e-04 - RoundingEqual: 0.7153 - val_loss: 6.0192e-04 - val_RoundingEqual: 0.7143
Epoch 69/80
329/329 [==============================] - 2s 8ms/step - loss: 5.9440e-04 - RoundingEqual: 0.7128 - val_loss: 6.0512e-04 - val_RoundingEqual: 0.7146
Epoch 70/80
329/329 [==============================] - 3s 8ms/step - loss: 5.7222e-04 - RoundingEqual: 0.7208 - val_loss: 7.2867e-04 - val_RoundingEqual: 0.6625
Epoch 71/80
329/329 [==============================] - 2s 8ms/step - loss: 5.5773e-04 - RoundingEqual: 0.7276 - val_loss: 5.1591e-04 - val_RoundingEqual: 0.7394
Epoch 72/80
329/329 [==============================] - 3s 8ms/step - loss: 5.5730e-04 - RoundingEqual: 0.7275 - val_loss: 5.4854e-04 - val_RoundingEqual: 0.7345
Epoch 73/80
329/329 [==============================] - 3s 8ms/step - loss: 5.4835e-04 - RoundingEqual: 0.7299 - val_loss: 5.7167e-04 - val_RoundingEqual: 0.7255
Epoch 74/80
329/329 [==============================] - 3s 8ms/step - loss: 5.5164e-04 - RoundingEqual: 0.7282 - val_loss: 4.9936e-04 - val_RoundingEqual: 0.7631
Epoch 75/80
329/329 [==============================] - 2s 8ms/step - loss: 5.2166e-04 - RoundingEqual: 0.7426 - val_loss: 5.7781e-04 - val_RoundingEqual: 0.7104
Epoch 76/80
329/329 [==============================] - 3s 8ms/step - loss: 5.2219e-04 - RoundingEqual: 0.7422 - val_loss: 4.7806e-04 - val_RoundingEqual: 0.7730
Epoch 77/80
329/329 [==============================] - 3s 8ms/step - loss: 5.0458e-04 - RoundingEqual: 0.7516 - val_loss: 4.8079e-04 - val_RoundingEqual: 0.7580
Epoch 78/80
329/329 [==============================] - 3s 8ms/step - loss: 5.0747e-04 - RoundingEqual: 0.7508 - val_loss: 5.1547e-04 - val_RoundingEqual: 0.7565
Epoch 79/80
329/329 [==============================] - 3s 9ms/step - loss: 4.9701e-04 - RoundingEqual: 0.7524 - val_loss: 5.1348e-04 - val_RoundingEqual: 0.7419
Epoch 80/80
329/329 [==============================] - 3s 8ms/step - loss: 4.8371e-04 - RoundingEqual: 0.7625 - val_loss: 4.3901e-04 - val_RoundingEqual: 0.7798
Best weights loaded
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]: actual is 0, predicted is 0.11382527649402618. PASS
[11, 12, 1, 1, 1, 0, 1, 1, 1, 0]: actual is 1, predicted is 1.0460457801818848. PASS
[14, 16, 13, 0, 0, 1, 2, 0, 0, 0]: actual is 2, predicted is 1.5106762647628784. PASS
[3, 0, 0, 1, 3, 0, 0, 2, 3, 0]: actual is 3, predicted is 3.0876657962799072. PASS
[7, 5, 9, 8, 6, 6, 9, 3, 4, 7]: actual is 4, predicted is 4.565701007843018. FAIL
[20, 13, 18, 6, 3, 0, 4, 1, 0, 2]: actual is 5, predicted is 4.612735748291016. PASS
[18, 6, 12, 3, 9, 4, 4, 0, 2, 2]: actual is 6, predicted is 6.360283374786377. PASS
[14, 12, 7, 2, 9, 5, 3, 2, 6, 5]: actual is 7, predicted is 6.619134902954102. PASS
[5, 8, 13, 7, 0, 0, 0, 8, 4, 8]: actual is 8, predicted is 8.578459739685059. FAIL
[2, 8, 2, 11, 6, 6, 6, 8, 1, 2]: actual is 9, predicted is 8.897890090942383. PASS
[5, 4, 4, 11, 14, 4, 1, 5, 6, 5]: actual is 10, predicted is 9.935867309570312. PASS
[4, 15, 6, 4, 7, 8, 7, 12, 7, 14]: actual is 11, predicted is 11.337911605834961. PASS
[8, 20, 7, 9, 1, 4, 12, 12, 9, 13]: actual is 12, predicted is 12.325309753417969. PASS
[7, 11, 20, 1, 6, 2, 5, 14, 3, 4]: actual is 13, predicted is 13.197027206420898. PASS
[6, 16, 10, 15, 13, 11, 20, 8, 12, 13]: actual is 14, predicted is 13.243918418884277. FAIL
[5, 10, 7, 20, 20, 3, 7, 7, 12, 9]: actual is 15, predicted is 14.622146606445312. PASS
[12, 0, 5, 8, 3, 8, 16, 7, 10, 5]: actual is 16, predicted is 16.36766815185547. PASS
[18, 3, 10, 11, 7, 20, 6, 6, 11, 1]: actual is 17, predicted is 16.63459587097168. PASS
[6, 14, 10, 0, 10, 2, 18, 10, 9, 11]: actual is 18, predicted is 18.820812225341797. FAIL
[17, 17, 8, 10, 16, 13, 10, 0, 14, 19]: actual is 19, predicted is 18.509498596191406. PASS
[17, 0, 12, 20, 20, 2, 10, 15, 3, 15]: actual is 20, predicted is 19.892822265625. PASS

Reference: https://machinelearningmastery.com/learn-add-numbers-seq2seq-recurrent-neural-networks/