RNN神经网络 python

本文介绍了如何使用Python库pandas、numpy进行数据预处理,包括读取CSV文件、数据归一化,然后使用Keras构建SimpleRNN模型对股票价格进行预测。通过可视化展示训练结果和测试数据的预测性能。
摘要由CSDN通过智能技术生成
import pandas as pd
import numpy as np
data=pd.read_cvs('data.cvs')
data.head()

price=data.loc[:,'close']
#归一化处理
price_norm=price/max(price)
//可视化
%matplotlib inline
from matplotilb import pyplot as plt
fig=plt.figure(figsize=(8,5))
plt.pot(price)
plt.title('data vs price')
plt.xlable('time')
plt.ylable('price')
plt.show()

//x,y赋值
 def extract_data(data,time_step):
        x=[]
        y=[]
        for i in range(len(data)-time_step)
            x.append([a for a in data[i:i+time_step]])
            y.append(data[i+time_step])
        x=np.array(x)
        x=x.reshape(x.shape[0],x.shape[1],1)
        return x,y
//数据提取
x,y=extract_data(price_norm,time_step)

//set up the model
from keras.models import Sequential
from keras.layers import Dense,SimpleRNN
model=()
model.add(SimpleRNN(units=5,input_shape=(time_step,1),activation='relu'))
model.add(Dense(units=1,actication='linear'))
model.comple(optimizer='adam',loss='mean_squared_error')
model.summary()

//模型训练
model.fit(x,y,batch_size=30,epochs=200)

//预测
y_train_price=model.preict(x)*max(price)
y_train=[i*max(price) for i in y]


fig1=plt.figure(figsize=(8,5))
plt.pot(y_train_price)
plt.pot(y_train)
plt.title('data vs price')
plt.xlable('time')
plt.ylable('price')
plt.show()

//测试数据
data_test=pd.read_csv('data_test.csv')
data_test.head()
price_test=data_test.loc[:,'close']
price_test_norm=price_test/max(price)
//提取数据
x_test_norm,y_test_norm=extract_data(price_test_norm,time_step)
y_test_price=model.preict(x_test_norm)*max(price)
y_test=[i* max(price) for i in  y_test_norm]

fig2=plt.figure(figsize=(8,5))
plt.pot(y_test_price,lable='y_test_price')
plt.pot(y_test,lable='y_test')
plt.title('data vs price')
plt.xlable('time')
plt.ylable('price')
plt.show()

//存储数据
result_y_test=np.array(y_test).reshape(-1,1)
result_y_test_predict=y_test_predict
result=np.concatenate((result_y_test,result_y_test_predict),axis=1)
result=pd.DataFrame(result,columns=['real_price_test','predict_price_test'])
result.to_csv("zz.csv")

  • 10
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值