Computing 121. Best Time to Buy and Sell Stock with an LSTM Neural Network
February 13, 2022. Original problems: https://leetcode.com/problems/best-time-to-buy-and-sell-stock/ and https://www.geeksforgeeks.org/maximum-difference-between-two-elements/
import random

import numpy as np
import tensorflow as tf
from tensorflow import keras

print(tf.__version__)
def maxDiff(arr):
    # Classic O(n) scan: track the minimum element seen so far and the
    # best difference achieved against it.
    max_diff = arr[1] - arr[0]
    min_element = arr[0]
    for i in range(1, len(arr)):
        if arr[i] - min_element > max_diff:
            max_diff = arr[i] - min_element
        if arr[i] < min_element:
            min_element = arr[i]
    return max_diff
def generateTargetArray(targetMaxDiff, length, maxNum):
    # Generate a random array of the given length, with values in [0, maxNum],
    # whose maxDiff is exactly targetMaxDiff.
    assert targetMaxDiff <= maxNum
    arr = [0] * length
    plannedMaxIndex = random.randint(1, length - 1)
    min_element = maxNum
    minIndex = 0
    for i in range(0, len(arr)):
        if i == plannedMaxIndex:
            if min_element + targetMaxDiff <= maxNum:
                arr[i] = min_element + targetMaxDiff
            else:
                # The planned peak would exceed maxNum, so shift the elements
                # generated so far down by the overflow d before placing it.
                d = min_element + targetMaxDiff - maxNum
                t = np.array(arr)
                t[minIndex:i] -= d
                arr = t.tolist()
                min_element -= d
                arr[i] = min_element + targetMaxDiff
                assert arr[i] <= maxNum
        else:
            # Cap the other elements at min_element + targetMaxDiff so that no
            # pair can exceed the target difference.
            arr[i] = random.randint(0, min(min_element + targetMaxDiff, maxNum))
        if arr[i] < min_element:
            min_element = arr[i]
            minIndex = i
    assert maxDiff(arr) == targetMaxDiff, arr
    return arr
class ProfitModel(tf.keras.Model):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # With return_sequences=True the LSTM returns its output at every time
        # step; with return_sequences=False it returns only the last step's
        # output. return_state=True additionally returns the final hidden
        # (memory) and cell (carry) states.
        self.lstm = tf.keras.layers.LSTM(64, return_sequences=True, return_state=True)
        self.dense = tf.keras.layers.Dense(1, activation='sigmoid')

    def call(self, inputs, training=None, mask=None):
        whole_seq_output, final_memory_state, final_carry_state = self.lstm(inputs)
        # Regress on the final cell (carry) state only.
        return self.dense(final_carry_state)
class RoundingMetric(tf.keras.metrics.Metric):
    # Accuracy after mapping labels and predictions back to [0, maxNum]
    # and rounding to the nearest integer.
    def __init__(self, maxNum, name='RoundingEqual'):
        super(RoundingMetric, self).__init__(name=name)
        self.maxNum = maxNum
        self.matches = self.add_weight('matches', dtype=tf.int32)
        self.total = self.add_weight('total', dtype=tf.int32)

    def update_state(self, y_true, y_pred, sample_weight=None):
        y_true = tf.math.rint(y_true * self.maxNum)
        y_pred = tf.math.rint(y_pred * self.maxNum)
        acc = tf.math.equal(y_true, y_pred)
        self.matches.assign_add(tf.reduce_sum(tf.cast(acc, tf.int32)))
        # tf.shape handles the smaller final batch in graph mode, where len() can fail.
        self.total.assign_add(tf.shape(y_true)[0])

    def result(self):
        return self.matches / self.total
if __name__ == '__main__':
    X = []
    Y = []
    maxNum = 20
    # 2000 arrays for each target max-diff in [0, maxNum].
    for n in range(2000):
        for d in range(maxNum + 1):
            arr = generateTargetArray(d, 10, maxNum)
            X.append(arr)
            Y.append(d / maxNum)  # scale labels into the sigmoid's [0, 1] range
    X = tf.constant(X, dtype=tf.float32) / maxNum
    X = tf.expand_dims(X, -1)  # (samples, timesteps, features) for the LSTM
    Y = tf.constant(Y)
    size = X.shape[0]
    validationSize = int(size * 0.2)
    validationX = X[:validationSize]
    validationY = Y[:validationSize]
    trainingX = X[validationSize:]
    trainingY = Y[validationSize:]
    print(f'Training size is {trainingX.shape[0]}. Validation size is {validationX.shape[0]}.')
    model = ProfitModel()
    model.compile(optimizer='adam', loss='mse', metrics=[RoundingMetric(maxNum)])
    checkpoint_filepath = 'saved-models/lstm'
    checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_RoundingEqual', mode='max',
        save_best_only=True, verbose=0)
    # Fit on the training split only, so the validation data stays unseen.
    model.fit(trainingX, trainingY, epochs=80, batch_size=128,
              validation_data=(validationX, validationY),
              callbacks=[checkpoint_callback])
    model.load_weights(checkpoint_filepath)
    print('Best weights loaded')
    # Evaluate on freshly generated, unseen arrays.
    testX = []
    testY = []
    for d in range(maxNum + 1):
        arr = generateTargetArray(d, 10, maxNum)
        testX.append(arr)
        testY.append(d)
    X = tf.constant(testX, dtype=tf.float32) / maxNum
    X = tf.expand_dims(X, -1)
    predicts = model.predict(X) * maxNum
    for i in range(len(X)):
        result = 'PASS' if testY[i] == round(predicts[i][0]) else 'FAIL'
        print(f'{testX[i]}: actual is {testY[i]}, predicted is {predicts[i][0]}. {result}')
The baseline algorithm is maxDiff. I first use generateTargetArray to produce arrays with a given targetMaxDiff. For example, generateTargetArray(5, 10, 20) generates arrays such as

[18, 20, 2, 7, 0, 1, 3, 0, 2, 4]
[20, 16, 20, 12, 11, 15, 16, 6, 3, 8]
[15, 20, 19, 10, 12, 5, 6, 4, 6, 6]
[18, 19, 12, 8, 3, 8, 5, 6, 5, 6]

each of which has a maxDiff of 5. This is how the dataset is built; it is then split into a training set and a validation set.
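Those sample arrays can be sanity-checked directly. A minimal standalone sketch (max_diff restates the one-pass scan of maxDiff above so the snippet runs on its own):

```python
def max_diff(arr):
    # One-pass scan: track the smallest element so far and the best difference.
    best = arr[1] - arr[0]
    lowest = arr[0]
    for x in arr[1:]:
        best = max(best, x - lowest)
        lowest = min(lowest, x)
    return best

samples = [
    [18, 20, 2, 7, 0, 1, 3, 0, 2, 4],
    [20, 16, 20, 12, 11, 15, 16, 6, 3, 8],
    [15, 20, 19, 10, 12, 5, 6, 4, 6, 6],
    [18, 19, 12, 8, 3, 8, 5, 6, 5, 6],
]
print([max_diff(a) for a in samples])  # → [5, 5, 5, 5]
```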
The TensorFlow model is ProfitModel, written in the subclassing style, with an LSTM (a kind of RNN) as its inner layer. The LSTM's final cell state (the C state) serves as the output, which then passes through a dense layer whose sigmoid activation squashes it into the [0, 1] range. After training, the model is tested on a freshly generated dataset.
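Since the sigmoid keeps the output in [0, 1], labels are encoded as d / maxNum and a prediction is decoded back by multiplying with maxNum and rounding. A small sketch of that round-trip (plain Python; encode and decode are hypothetical helper names, not part of the script above):

```python
MAX_NUM = 20

def encode(d, max_num=MAX_NUM):
    # Scale an integer max-diff into the sigmoid's [0, 1] output range.
    return d / max_num

def decode(y, max_num=MAX_NUM):
    # Map a raw model output back to an integer max-diff prediction.
    return round(y * max_num)

print(decode(encode(7)))  # → 7
print(decode(0.231))      # → 5 (0.231 * 20 = 4.62, rounded)
```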
After 80 epochs the accuracy reaches about 0.75. The training shows no sign of overfitting, so increasing the number of epochs and the number of neurons remains feasible.
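The 0.75 figure is the RoundingEqual metric. Its rescale-round-compare logic can be mirrored outside TensorFlow for a quick offline check (a plain-Python sketch; rounding_accuracy is a hypothetical name, not part of the script above):

```python
def rounding_accuracy(y_true, y_pred, max_num=20):
    # Mirror of RoundingMetric: rescale both sides back to [0, max_num],
    # round to the nearest integer, and count exact matches.
    matches = sum(
        round(t * max_num) == round(p * max_num)
        for t, p in zip(y_true, y_pred)
    )
    return matches / len(y_true)

# Labels 4/20 and 5/20; the first prediction rounds to 5 and misses.
print(rounding_accuracy([0.20, 0.25], [0.228, 0.26]))  # → 0.5
```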
2.6.0
Training size is 33600. Validation size is 8400.
2022-02-13 16:52:11.525719: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/80
329/329 [==============================] - 7s 11ms/step - loss: 0.0271 - RoundingEqual: 0.1402 - val_loss: 0.0213 - val_RoundingEqual: 0.1571
Epoch 21/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0029 - RoundingEqual: 0.3775 - val_loss: 0.0027 - val_RoundingEqual: 0.3776
Epoch 22/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0027 - RoundingEqual: 0.3911 - val_loss: 0.0026 - val_RoundingEqual: 0.3858
Epoch 23/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0025 - RoundingEqual: 0.4052 - val_loss: 0.0023 - val_RoundingEqual: 0.4180
Epoch 24/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0023 - RoundingEqual: 0.4164 - val_loss: 0.0022 - val_RoundingEqual: 0.3898
Epoch 25/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0022 - RoundingEqual: 0.4304 - val_loss: 0.0022 - val_RoundingEqual: 0.4170
Epoch 26/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0021 - RoundingEqual: 0.4395 - val_loss: 0.0020 - val_RoundingEqual: 0.4442
Epoch 27/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0019 - RoundingEqual: 0.4577 - val_loss: 0.0021 - val_RoundingEqual: 0.4068
Epoch 28/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0017 - RoundingEqual: 0.4707 - val_loss: 0.0018 - val_RoundingEqual: 0.4687
Epoch 29/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0016 - RoundingEqual: 0.4865 - val_loss: 0.0016 - val_RoundingEqual: 0.4956
Epoch 30/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0016 - RoundingEqual: 0.4960 - val_loss: 0.0015 - val_RoundingEqual: 0.5224
Epoch 31/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0015 - RoundingEqual: 0.4996 - val_loss: 0.0017 - val_RoundingEqual: 0.4715
Epoch 32/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0015 - RoundingEqual: 0.5159 - val_loss: 0.0014 - val_RoundingEqual: 0.5315
Epoch 33/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0014 - RoundingEqual: 0.5185 - val_loss: 0.0016 - val_RoundingEqual: 0.4849
Epoch 34/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5312 - val_loss: 0.0013 - val_RoundingEqual: 0.5235
Epoch 35/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5378 - val_loss: 0.0012 - val_RoundingEqual: 0.5468
Epoch 36/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0013 - RoundingEqual: 0.5440 - val_loss: 0.0012 - val_RoundingEqual: 0.5538
Epoch 37/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0012 - RoundingEqual: 0.5502 - val_loss: 0.0012 - val_RoundingEqual: 0.5526
Epoch 38/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0012 - RoundingEqual: 0.5576 - val_loss: 0.0012 - val_RoundingEqual: 0.5731
Epoch 39/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5651 - val_loss: 0.0012 - val_RoundingEqual: 0.5452
Epoch 40/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5710 - val_loss: 0.0012 - val_RoundingEqual: 0.5705
Epoch 41/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5741 - val_loss: 0.0010 - val_RoundingEqual: 0.5973
Epoch 42/80
329/329 [==============================] - 2s 8ms/step - loss: 0.0011 - RoundingEqual: 0.5807 - val_loss: 9.9008e-04 - val_RoundingEqual: 0.6115
Epoch 43/80
329/329 [==============================] - 3s 8ms/step - loss: 0.0010 - RoundingEqual: 0.5907 - val_loss: 9.5682e-04 - val_RoundingEqual: 0.6182
Epoch 44/80
329/329 [==============================] - 3s 8ms/step - loss: 9.9314e-04 - RoundingEqual: 0.5973 - val_loss: 0.0011 - val_RoundingEqual: 0.5685
Epoch 45/80
329/329 [==============================] - 2s 8ms/step - loss: 9.9009e-04 - RoundingEqual: 0.5951 - val_loss: 0.0010 - val_RoundingEqual: 0.6017
Epoch 46/80
329/329 [==============================] - 3s 8ms/step - loss: 9.4755e-04 - RoundingEqual: 0.6051 - val_loss: 9.4081e-04 - val_RoundingEqual: 0.6105
Epoch 47/80
329/329 [==============================] - 3s 8ms/step - loss: 9.3701e-04 - RoundingEqual: 0.6098 - val_loss: 0.0010 - val_RoundingEqual: 0.5693
Epoch 48/80
329/329 [==============================] - 2s 8ms/step - loss: 9.0903e-04 - RoundingEqual: 0.6179 - val_loss: 9.2725e-04 - val_RoundingEqual: 0.6096
Epoch 49/80
329/329 [==============================] - 2s 8ms/step - loss: 8.8730e-04 - RoundingEqual: 0.6190 - val_loss: 8.4559e-04 - val_RoundingEqual: 0.6237
Epoch 50/80
329/329 [==============================] - 3s 8ms/step - loss: 8.7649e-04 - RoundingEqual: 0.6241 - val_loss: 8.8519e-04 - val_RoundingEqual: 0.6263
Epoch 51/80
329/329 [==============================] - 3s 8ms/step - loss: 8.4369e-04 - RoundingEqual: 0.6302 - val_loss: 8.1327e-04 - val_RoundingEqual: 0.6539
Epoch 52/80
329/329 [==============================] - 3s 8ms/step - loss: 8.3751e-04 - RoundingEqual: 0.6342 - val_loss: 8.1143e-04 - val_RoundingEqual: 0.6482
Epoch 53/80
329/329 [==============================] - 2s 8ms/step - loss: 8.0424e-04 - RoundingEqual: 0.6409 - val_loss: 7.7465e-04 - val_RoundingEqual: 0.6392
Epoch 54/80
329/329 [==============================] - 3s 8ms/step - loss: 7.8093e-04 - RoundingEqual: 0.6496 - val_loss: 7.9678e-04 - val_RoundingEqual: 0.6342
Epoch 55/80
329/329 [==============================] - 2s 8ms/step - loss: 7.7724e-04 - RoundingEqual: 0.6499 - val_loss: 7.7000e-04 - val_RoundingEqual: 0.6479
Epoch 56/80
329/329 [==============================] - 3s 8ms/step - loss: 7.4981e-04 - RoundingEqual: 0.6593 - val_loss: 9.3061e-04 - val_RoundingEqual: 0.6290
Epoch 57/80
329/329 [==============================] - 3s 8ms/step - loss: 7.5051e-04 - RoundingEqual: 0.6585 - val_loss: 8.4115e-04 - val_RoundingEqual: 0.6277
Epoch 58/80
329/329 [==============================] - 3s 8ms/step - loss: 7.3048e-04 - RoundingEqual: 0.6629 - val_loss: 7.7131e-04 - val_RoundingEqual: 0.6632
Epoch 59/80
329/329 [==============================] - 3s 8ms/step - loss: 7.1310e-04 - RoundingEqual: 0.6714 - val_loss: 7.7137e-04 - val_RoundingEqual: 0.6651
Epoch 60/80
329/329 [==============================] - 3s 8ms/step - loss: 6.9212e-04 - RoundingEqual: 0.6758 - val_loss: 7.0888e-04 - val_RoundingEqual: 0.6527
Epoch 61/80
329/329 [==============================] - 3s 8ms/step - loss: 6.8705e-04 - RoundingEqual: 0.6767 - val_loss: 6.7692e-04 - val_RoundingEqual: 0.6780
Epoch 62/80
329/329 [==============================] - 3s 8ms/step - loss: 6.6169e-04 - RoundingEqual: 0.6855 - val_loss: 7.4443e-04 - val_RoundingEqual: 0.6480
Epoch 63/80
329/329 [==============================] - 3s 8ms/step - loss: 6.5932e-04 - RoundingEqual: 0.6884 - val_loss: 6.1253e-04 - val_RoundingEqual: 0.7054
Epoch 64/80
329/329 [==============================] - 3s 8ms/step - loss: 6.3962e-04 - RoundingEqual: 0.6973 - val_loss: 6.6404e-04 - val_RoundingEqual: 0.6770
Epoch 65/80
329/329 [==============================] - 3s 8ms/step - loss: 6.2708e-04 - RoundingEqual: 0.7022 - val_loss: 6.5355e-04 - val_RoundingEqual: 0.7019
Epoch 66/80
329/329 [==============================] - 3s 8ms/step - loss: 6.1974e-04 - RoundingEqual: 0.7034 - val_loss: 6.1600e-04 - val_RoundingEqual: 0.6987
Epoch 67/80
329/329 [==============================] - 3s 8ms/step - loss: 6.0954e-04 - RoundingEqual: 0.7050 - val_loss: 5.8688e-04 - val_RoundingEqual: 0.7173
Epoch 68/80
329/329 [==============================] - 3s 8ms/step - loss: 5.9655e-04 - RoundingEqual: 0.7153 - val_loss: 6.0192e-04 - val_RoundingEqual: 0.7143
Epoch 69/80
329/329 [==============================] - 2s 8ms/step - loss: 5.9440e-04 - RoundingEqual: 0.7128 - val_loss: 6.0512e-04 - val_RoundingEqual: 0.7146
Epoch 70/80
329/329 [==============================] - 3s 8ms/step - loss: 5.7222e-04 - RoundingEqual: 0.7208 - val_loss: 7.2867e-04 - val_RoundingEqual: 0.6625
Epoch 71/80
329/329 [==============================] - 2s 8ms/step - loss: 5.5773e-04 - RoundingEqual: 0.7276 - val_loss: 5.1591e-04 - val_RoundingEqual: 0.7394
Epoch 72/80
329/329 [==============================] - 3s 8ms/step - loss: 5.5730e-04 - RoundingEqual: 0.7275 - val_loss: 5.4854e-04 - val_RoundingEqual: 0.7345
Epoch 73/80
329/329 [==============================] - 3s 8ms/step - loss: 5.4835e-04 - RoundingEqual: 0.7299 - val_loss: 5.7167e-04 - val_RoundingEqual: 0.7255
Epoch 74/80
329/329 [==============================] - 3s 8ms/step - loss: 5.5164e-04 - RoundingEqual: 0.7282 - val_loss: 4.9936e-04 - val_RoundingEqual: 0.7631
Epoch 75/80
329/329 [==============================] - 2s 8ms/step - loss: 5.2166e-04 - RoundingEqual: 0.7426 - val_loss: 5.7781e-04 - val_RoundingEqual: 0.7104
Epoch 76/80
329/329 [==============================] - 3s 8ms/step - loss: 5.2219e-04 - RoundingEqual: 0.7422 - val_loss: 4.7806e-04 - val_RoundingEqual: 0.7730
Epoch 77/80
329/329 [==============================] - 3s 8ms/step - loss: 5.0458e-04 - RoundingEqual: 0.7516 - val_loss: 4.8079e-04 - val_RoundingEqual: 0.7580
Epoch 78/80
329/329 [==============================] - 3s 8ms/step - loss: 5.0747e-04 - RoundingEqual: 0.7508 - val_loss: 5.1547e-04 - val_RoundingEqual: 0.7565
Epoch 79/80
329/329 [==============================] - 3s 9ms/step - loss: 4.9701e-04 - RoundingEqual: 0.7524 - val_loss: 5.1348e-04 - val_RoundingEqual: 0.7419
Epoch 80/80
329/329 [==============================] - 3s 8ms/step - loss: 4.8371e-04 - RoundingEqual: 0.7625 - val_loss: 4.3901e-04 - val_RoundingEqual: 0.7798
Best weights loaded
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0]: actual is 0, predicted is 0.11382527649402618. PASS
[11, 12, 1, 1, 1, 0, 1, 1, 1, 0]: actual is 1, predicted is 1.0460457801818848. PASS
[14, 16, 13, 0, 0, 1, 2, 0, 0, 0]: actual is 2, predicted is 1.5106762647628784. PASS
[3, 0, 0, 1, 3, 0, 0, 2, 3, 0]: actual is 3, predicted is 3.0876657962799072. PASS
[7, 5, 9, 8, 6, 6, 9, 3, 4, 7]: actual is 4, predicted is 4.565701007843018. FAIL
[20, 13, 18, 6, 3, 0, 4, 1, 0, 2]: actual is 5, predicted is 4.612735748291016. PASS
[18, 6, 12, 3, 9, 4, 4, 0, 2, 2]: actual is 6, predicted is 6.360283374786377. PASS
[14, 12, 7, 2, 9, 5, 3, 2, 6, 5]: actual is 7, predicted is 6.619134902954102. PASS
[5, 8, 13, 7, 0, 0, 0, 8, 4, 8]: actual is 8, predicted is 8.578459739685059. FAIL
[2, 8, 2, 11, 6, 6, 6, 8, 1, 2]: actual is 9, predicted is 8.897890090942383. PASS
[5, 4, 4, 11, 14, 4, 1, 5, 6, 5]: actual is 10, predicted is 9.935867309570312. PASS
[4, 15, 6, 4, 7, 8, 7, 12, 7, 14]: actual is 11, predicted is 11.337911605834961. PASS
[8, 20, 7, 9, 1, 4, 12, 12, 9, 13]: actual is 12, predicted is 12.325309753417969. PASS
[7, 11, 20, 1, 6, 2, 5, 14, 3, 4]: actual is 13, predicted is 13.197027206420898. PASS
[6, 16, 10, 15, 13, 11, 20, 8, 12, 13]: actual is 14, predicted is 13.243918418884277. FAIL
[5, 10, 7, 20, 20, 3, 7, 7, 12, 9]: actual is 15, predicted is 14.622146606445312. PASS
[12, 0, 5, 8, 3, 8, 16, 7, 10, 5]: actual is 16, predicted is 16.36766815185547. PASS
[18, 3, 10, 11, 7, 20, 6, 6, 11, 1]: actual is 17, predicted is 16.63459587097168. PASS
[6, 14, 10, 0, 10, 2, 18, 10, 9, 11]: actual is 18, predicted is 18.820812225341797. FAIL
[17, 17, 8, 10, 16, 13, 10, 0, 14, 19]: actual is 19, predicted is 18.509498596191406. PASS
[17, 0, 12, 20, 20, 2, 10, 15, 3, 15]: actual is 20, predicted is 19.892822265625. PASS
Reference: https://machinelearningmastery.com/learn-add-numbers-seq2seq-recurrent-neural-networks/