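"""Forecast "international-airline-passengers.csv" with a basic LSTM recurrent network.

The data set covers the 12 years from 1949 to 1960 with 12 monthly values
per year, 144 data points in total. The SeriesPredictor class defines the
methods for building, training and testing the model, and the plot_result
method draws a line chart of the prediction results.
"""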
import numpy as np
import tensorflow as tf
from tensorflow.contrib import rnn
import matplotlib.pyplot as plt
class SeriesPredictor:
def __init__(self, input_dim, seq_size, hidden_dim):
    """Create the variables, placeholders, cost and training op.

    :param input_dim: size of the last axis of the input placeholder ``x``
    :param seq_size: size of the second axis of both ``x`` and ``y``
    :param hidden_dim: number of rows of the output weight ``W_out``
    """
    # Remember the hyperparameters on the instance.
    self.input_dim = input_dim
    self.seq_size = seq_size
    self.hidden_dim = hidden_dim

    # Output weight [hidden_dim, 1] and bias [1], initialised from
    # tf.random_normal.
    weight_init = tf.random_normal([hidden_dim, 1])
    self.W_out = tf.Variable(weight_init, name='W_out')
    bias_init = tf.random_normal([1])
    self.b_out = tf.Variable(bias_init, name='b_out')

    # float32 placeholders; the leading dimension is left unspecified.
    x_shape = [None, seq_size, input_dim]
    y_shape = [None, seq_size]
    self.x = tf.placeholder(tf.float32, x_shape)
    self.y = tf.placeholder(tf.float32, y_shape)

    # Cost: mean of the squared difference between self.model() and self.y.
    prediction = self.model()
    squared_error = tf.square(prediction - self.y)
    self.cost = tf.reduce_mean(squared_error)

    # Training op: Adam with learning rate 0.01 minimising the cost.
    optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
    self.train_op = optimizer.minimize(self.cost)

    # Auxiliary op: a Saver built with default arguments.
    self.saver = tf.train.Saver()
def model(self):
"""
:param x: inputs of size [T, batch_size, input_size]
:param W: matrix of