# 定义模型参数
class NetConfig():
def __init__(self):
self.rnn_unit = 10
self.input_size = 7
self.output_size = 1
self.lr = 0.0006
self.time_step = 20
self.batch_size = 80
self.weights = {
"in":tf.Variable(tf.random_normal([self.input_size,self.rnn_unit])),
"out":tf.Variable(tf.random_normal([self.rnn_unit,self.output_size]))
}
self.biases = {
"in":tf.Variable(tf.constant(0.1,shape=[self.rnn_unit,])),
"out":tf.Variable(tf.constant(0.1,shape=[self.output_size,]))
}
# 定义训练模型
class MyModel():
def __init__(self):
self.sess = tf.Session()
self.nc = NetConfig()
def train_rnn(self,data,save_path,iter_num):
weights = self.nc.weights
biases = self.nc.biases
input_size = self.nc.input_size
rnn_unit = self.nc.rnn_unit
output_size = self.nc.output_size
time_step = self.nc.time_step
batch_size = self.nc.batch_size
lr = self.nc.lr
sess = self.sess
X = tf.placeholder(tf.float32,shape=
RNN预测股票数据---模型和参数
最新推荐文章于 2024-08-19 10:12:10 发布
本文探讨了如何利用循环神经网络(RNN)进行股票市场预测,详细介绍了模型构建过程及关键参数设置,帮助读者理解RNN在金融数据分析中的应用。
摘要由CSDN通过智能技术生成