多维度LSTM(长短期记忆)神经网络预测未来客户活期存款余额

'''
Created on 2020年10月26日
@author: 寻找手艺人
'''
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.preprocessing import MinMaxScaler
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import LSTM
from keras.layers import Dropout
import os
# Raise TensorFlow's C++ log threshold (by TensorFlow's convention '2' filters
# INFO and WARNING messages).
# NOTE(review): this line runs after the keras imports above, which already
# load TensorFlow, so messages emitted during that import may not be
# suppressed -- consider moving it above the keras imports.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def train_data(window=60, cutoff='2020-10-10'):
    """Load the pre-cutoff history and build sliding-window training samples.

    Reads ``alive_data.xlsx`` and keeps only the rows dated before ``cutoff``.
    The rows on/after ``cutoff`` are the evaluation period used by
    ``test_data``, so training on them would leak the values the plot
    compares against. The six feature columns are then scaled to [0, 1] and
    cut into overlapping windows of ``window`` consecutive rows.

    Args:
        window: number of consecutive rows per input sequence (default 60).
        cutoff: first date of the held-out evaluation period. ``None``
            trains on every row of the sheet (the previous behaviour).

    Returns:
        ``(X_train, y_train, dataset_train, sc)``:

        - ``X_train`` has shape ``(samples, window, n_features)``.
        - ``y_train`` has shape ``(samples, n_features)``; each target row is
          the scaled row that immediately follows its window.
        - ``dataset_train`` is the DataFrame actually used for training.
        - ``sc`` is the fitted ``MinMaxScaler``.

    Raises:
        ValueError: if there are not more than ``window`` training rows, so
            no complete (window, target) pair can be formed.
    """
    dataset_train = pd.read_excel('alive_data.xlsx', parse_dates=['date'])
    if cutoff is not None:
        # Hold out the evaluation period (assumes the sheet is sorted by date
        # ascending, as the windowing below also does -- TODO confirm).
        dataset_train = dataset_train[dataset_train['date'] < cutoff]

    # Columns 1..6 are taken as the features, with 'date' presumably in
    # column 0 -- verify the sheet layout matches the names used in test_data.
    training_set = dataset_train.iloc[:, 1:7].values
    sc = MinMaxScaler(feature_range=(0, 1))
    training_set_scaled = sc.fit_transform(training_set)

    n_rows = training_set_scaled.shape[0]
    if n_rows <= window:
        raise ValueError(
            f"need more than {window} training rows to build windows, "
            f"got {n_rows}"
        )

    # Sample i is rows [i - window, i); its target is row i.
    X_train = np.array(
        [training_set_scaled[i - window:i, :] for i in range(window, n_rows)]
    )
    y_train = training_set_scaled[window:]
    return X_train, y_train, dataset_train, sc
def test_data(dataset_train, sc, window=60, cutoff='2020-10-10'):
    """Build scaled look-back windows for every row dated on/after ``cutoff``.

    For each evaluation row, the input is the ``window`` rows immediately
    preceding it in the sheet, scaled with the training scaler. Later
    evaluation rows therefore see earlier evaluation rows' observed values
    (one-step-ahead evaluation).

    The previous implementation concatenated ``dataset_train`` with the
    evaluation rows and took the last ``window + len(test)`` rows. When
    ``dataset_train`` already contained the evaluation period, this produced
    windows that started at the end of the file and contained future rows.
    Windows are now taken positionally from the full sheet, which is correct
    whether or not the training set overlaps the evaluation period.

    Args:
        dataset_train: training DataFrame returned by ``train_data``. Kept for
            interface compatibility; no longer needed to build the windows.
        sc: the ``MinMaxScaler`` fitted by ``train_data``.
        window: look-back length; must match the one used for training.
        cutoff: first date of the evaluation period.

    Returns:
        ``(X_test, real_stock_price)``:

        - ``X_test`` has shape ``(n_eval_rows, window, 6)``.
        - ``real_stock_price`` holds the actual ``alive_amt`` values of the
          evaluation rows, in sheet order.

    Raises:
        ValueError: if fewer than ``window`` rows precede the first
            evaluation row.
    """
    feature_cols = ['first6', 'birthday', 'last4', 'gender', 'age', 'alive_amt']
    df = pd.read_excel('alive_data.xlsx', parse_dates=['date'])

    # Boolean mask over sheet rows; assumes the sheet is sorted by date
    # ascending so the rows before each evaluation row are its history.
    is_test = (df['date'] >= cutoff).values
    dataset_test = df.loc[is_test, feature_cols]
    real_stock_price = dataset_test['alive_amt'].values

    # MinMaxScaler transforms row by row, so scaling the whole sheet once
    # gives the same values as scaling each window separately.
    inputs = sc.transform(df.loc[:, feature_cols].values)

    test_positions = np.flatnonzero(is_test)
    if test_positions.size == 0:
        return np.empty((0, window, len(feature_cols))), real_stock_price
    if test_positions[0] < window:
        raise ValueError(
            f"need at least {window} rows before {cutoff} to build the first "
            f"test window, got {test_positions[0]}"
        )

    X_test = np.array([inputs[pos - window:pos, :] for pos in test_positions])
    return X_test, real_stock_price
def stock_model(X_train, y_train, epochs=1000, batch_size=32):
    """Build, compile and fit a four-layer stacked LSTM regressor.

    The network is four 50-unit LSTM layers, each followed by 20% dropout,
    then a dense output layer with one unit per target column. It is
    trained with Adam on mean squared error.

    The input width and output width are derived from the data rather than
    hard-coded, so the model follows the number of features produced by
    ``train_data``.

    Args:
        X_train: input windows of shape ``(samples, timesteps, n_features)``.
        y_train: targets of shape ``(samples, n_outputs)``.
        epochs: number of training epochs (default 1000).
        batch_size: mini-batch size (default 32).

    Returns:
        The fitted keras ``Sequential`` model.
    """
    n_timesteps, n_features = X_train.shape[1], X_train.shape[2]
    n_outputs = y_train.shape[1]

    regressor = Sequential()
    regressor.add(LSTM(units=50, return_sequences=True,
                       input_shape=(n_timesteps, n_features)))
    regressor.add(Dropout(0.2))
    # Two more sequence-returning LSTM layers, so the next LSTM receives a
    # full sequence.
    for _ in range(2):
        regressor.add(LSTM(units=50, return_sequences=True))
        regressor.add(Dropout(0.2))
    # The last LSTM returns only its final state, which feeds the dense head.
    regressor.add(LSTM(units=50))
    regressor.add(Dropout(0.2))
    regressor.add(Dense(units=n_outputs))

    regressor.compile(optimizer='adam', loss='mean_squared_error')
    regressor.fit(X_train, y_train, epochs=epochs, batch_size=batch_size)
    return regressor
def main():
    """Train the LSTM, predict the evaluation period and plot real vs predicted."""
    X_train, y_train, dataset_train, sc = train_data()
    model = stock_model(X_train, y_train)
    X_test, actual = test_data(dataset_train, sc)

    # The model predicts all scaled feature columns; undo the scaling and keep
    # column 5, presumably alive_amt (last of the six features) -- confirm
    # against the sheet layout.
    predicted = sc.inverse_transform(model.predict(X_test))
    predicted_alive_amt = predicted[:, 5]

    plt.plot(actual, color='red', label='Real Alive amt')
    plt.plot(predicted_alive_amt, color='green', label='Predicted Alive amt')
    plt.title('Alive amt Prediction')
    plt.xlabel('Time')
    plt.ylabel('Alive amt Price')
    plt.legend()
    plt.show()
# Run the train / predict / plot pipeline only when executed as a script.
if __name__ == '__main__':
    main()
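requirements.txt (this list appears to be a full `pip freeze` of the author's environment. The script itself only imports numpy, matplotlib, pandas, scikit-learn and Keras, with TensorFlow as the Keras backend; pandas also needs xlrd or openpyxl to read the .xlsx file. The `sequential` package is not required: `Sequential` comes from `keras.models`.)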
absl-py==0.10.0
asgiref==3.2.10
astor==0.8.1
astunparse==1.6.3
beautifulsoup4==4.9.1
bs4==0.0.1
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
Django==3.1.1
django-bootstrap4==2.2.0
et-xmlfile==1.0.1
gast==0.2.2
google-auth==1.22.1
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
grpcio==1.33.1
h5py==2.10.0
idna==2.10
importlib-metadata==1.7.0
jdcal==1.4.1
Jinja2==2.11.2
joblib==0.16.0
Keras==2.3.1
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
kiwisolver==1.2.0
lxml==4.5.2
Markdown==3.3.2
MarkupSafe==1.1.1
matplotlib==3.3.2
mpl-finance==0.10.1
mplfinance==0.12.7a0
numpy==1.18.5
oauthlib==3.1.0
openpyxl==3.0.5
opt-einsum==3.3.0
pandas==1.1.3
Pillow==7.2.0
prettytable==0.7.2
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyecharts==1.8.1
PyMySQL==0.10.1
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.1
PyYAML==5.3.1
requests==2.24.0
requests-oauthlib==1.3.0
rsa==4.6
scikit-learn==0.23.2
scipy==1.4.1
seaborn==0.11.0
sequential==1.0.0
simplejson==3.17.2
six==1.15.0
soupsieve==2.0.1
sqlparse==0.3.1
tensorboard==2.0.2
tensorboard-plugin-wit==1.7.0
tensorflow==2.0.0
tensorflow-estimator==2.0.1
termcolor==1.1.0
threadpoolctl==2.1.0
tushare==1.2.61
urllib3==1.25.10
websocket-client==0.57.0
Werkzeug==1.0.1
wrapt==1.12.1
xlrd==1.2.0
zipp==3.1.0