以下代码来源于《LSTM预测客流量》:
# -*- coding:utf-8 -*-
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense, LSTM
# Load the daily passenger counts, parse the ``time`` column into
# timestamps and use it as the index.
data = pd.read_csv('./data/单月每日客流1.csv')
data = data.assign(time=pd.to_datetime(data['time'])).set_index('time')
# Plot the trend of the passenger series
def get_picture(data=data, figsize=(16, 8)):
    """Plot the ``passengers`` column of *data* as a line chart and show it.

    Args:
        data: DataFrame with a ``passengers`` column; defaults to the
            module-level ``data`` frame.
        figsize: Figure size in inches as ``(width, height)``.  The previous
            hard-coded ``(100, 50)`` was applied to a separate empty figure
            and never affected the chart, so a practical default with the
            same 2:1 aspect ratio is used instead.
    """
    # Create the figure *before* plotting so the size applies to the chart.
    # Calling plt.figure() after .plot() only opens a second, empty figure.
    _, ax = plt.subplots(figsize=figsize)
    data['passengers'].plot(ax=ax)
    plt.show()
# Convert the series into sliding windows
def processing(data=data, long=11):
    """Slice the ``passengers`` column into overlapping windows.

    Window ``i`` holds ``long`` consecutive values starting at row ``i``, so
    neighbouring windows are shifted by one row.

    Args:
        data: DataFrame with a numeric ``passengers`` column; defaults to the
            module-level ``data`` frame.  It is not modified.
        long: Number of values per window (default 11).

    Returns:
        A float ``numpy.ndarray`` of shape ``(len(data) - long + 1, long)``,
        or of shape ``(0, long)`` when *data* has fewer than ``long`` rows.

    Raises:
        ValueError: If ``long`` is smaller than 1.
    """
    if long < 1:
        raise ValueError('long must be >= 1, got {}'.format(long))

    # Convert into a separate float array instead of casting the column in
    # place, so the caller's frame (and the shared default) stays untouched.
    values = np.asarray(data['passengers'], dtype=float)

    # Clamp at zero so short input reports "0 samples" rather than a
    # negative count.
    sample = max(len(values) - long + 1, 0)
    print('得到{}个样本'.format(sample))
    if sample == 0:
        # Keep the result 2-D so column slicing downstream still works.
        return np.empty((0, long), dtype=float)

    # The window width must follow ``long``; a hard-coded 11 here would give
    # wrong or ragged windows for any other window length.
    return np.array([values[i:i + long] for i in range(sample)])
# Train the LSTM network
def lstm(input_data=None):
    """Train a two-layer LSTM on sliding windows and plot test predictions.

    Each row of *input_data* is one window: every column except the last is
    a model input and the last column is the value to predict.  The first
    80% of the rows are used for training and the remaining rows for
    testing.  Predictions (green dotted) and actual values (red solid) for
    the test rows are plotted in the original units.

    Args:
        input_data: 2-D array-like of shape ``(n_samples, window_length)``
            with at least two columns, e.g. the output of ``processing``.

    Returns:
        The fitted Keras ``Sequential`` model.

    Raises:
        ValueError: If *input_data* is missing, is not 2-D with at least two
            columns, or has too few rows to leave any training samples.
    """
    if input_data is None:
        raise ValueError('input_data is required')
    input_data = np.asarray(input_data, dtype=float)
    if input_data.ndim != 2 or input_data.shape[1] < 2:
        raise ValueError(
            'input_data must be 2-D with at least two columns, got shape '
            '{}'.format(input_data.shape))

    x = input_data[:, :-1]
    y = input_data[:, -1].reshape(-1, 1)  # the scaler expects 2-D input

    # Split by position: the first 80% of the rows train, the rest test.
    split = int(len(y) * 0.8)
    if split < 1:
        raise ValueError(
            'need at least 2 samples to split into train and test sets, '
            'got {}'.format(len(y)))

    # Fit the scalers on the training rows only.  Fitting on the full data
    # set would leak the test set's min/max into training.
    scaler_x = MinMaxScaler()
    scaler_y = MinMaxScaler()
    x_train = scaler_x.fit_transform(x[:split])
    x_test = scaler_x.transform(x[split:])
    y_train = scaler_y.fit_transform(y[:split])
    # The test targets are only used for plotting, so keep them unscaled.
    y_test = y[split:]

    # Add a trailing axis of size 1 to match input_shape=(timesteps, 1).
    x_train = x_train[:, :, np.newaxis]
    x_test = x_test[:, :, np.newaxis]

    model = Sequential()
    model.add(LSTM(50, input_shape=(x_train.shape[1], 1), return_sequences=True))
    model.add(LSTM(100))
    model.add(Dense(1, activation='linear'))
    model.compile(loss='mse', optimizer='rmsprop')
    print('Train...')
    model.fit(x_train, y_train, batch_size=8, epochs=300, validation_split=0.1)

    # Map the scaled predictions back to the original units.
    predict = scaler_y.inverse_transform(model.predict(x_test))

    plt.plot(predict, 'g:')
    plt.plot(y_test, 'r-')
    plt.show()
    return model
if __name__ == '__main__':
    # Optional preview of the raw series: get_picture()
    windows = processing()
    lstm(windows)
数据格式
本人觉得以下这篇文章对LSTM的使用写得挺好的,推荐大家看一下。
LSTM实战