【TensorFlow】LSTM(使用TFLearn预测正弦sin函数)

本文展示了如何使用TFLearn和LSTM神经网络预测正弦函数。通过离散化sin函数,创建数据集,并用TFLearn简化TensorFlow模型的构建。在代码实现过程中,作者遇到了多个错误,包括tf.unpack替换为tf.unstack,rnn_cell的变化,以及动态RNN的使用等,并逐一解决了这些问题。最终,LSTM模型成功预测出正弦曲线,证明了其在时间序列预测上的有效性。
摘要由CSDN通过智能技术生成

项目已上传至 GitHub —— sin_pre

数据生成


因为标准的循环神经网络模型预测的是离散的数值,所以需要将连续的 sin 函数曲线离散化

所谓离散化就是在一个给定的区间 [0,MAX] 内,通过有限个采样点模拟一个连续的曲线,即间隔相同距离取点

采样用的是 numpy.linspace() 函数,它可以创建一个等差序列,常用的参数有三个

  • start:起始值
  • stop:终止值,不包含在内
  • num:数列长度,默认为 50

然后使用一个 generate_data() 函数生成输入和输出,序列的第 i 项和后面的 TIMESTEPS-1 项合在一起作为输入,第 i + TIMESTEPS 项作为输出

TFLearn使用


TFlearn 对训练模型进行了一些封装,使 TensorFlow 更便于使用,如下示范了 TFLearn 的使用方法

from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat

learn = tf.contrib.learn

# 建立深层循环网络模型
regressor = SKCompat(learn.Estimator(model_fn=lstm_model, model_dir='model/'))

# 调用fit函数训练模型
regressor.fit(train_x, train_y, batch_size=BATCH_SIZE, steps=TRAINGING_STEPS)

# 使用训练好的模型对测试集进行预测
predicted = [[pred] for pred in regressor.predict(test_x)]

完整代码


该代码实现自《TensorFlow:实战Google深度学习框架》

整个代码的结构如下

  • lstm_model() 类用于创建 LSTM 网络并返回一些结果
  • LstmCell() 函数用于创建单层 LSTM 结构,防止 LSTM 参数名称一样
  • generate_data() 函数用于创建数据集

由于原书中的代码是基于 1.0,而我用的是 1.5,所以出现了很多错误,我将所遇到的错误的解决方法都记录在了文末

import numpy as np
import tensorflow as tf
import matplotlib as mpl
from matplotlib import pyplot as plt
from tensorflow.contrib.learn.python.learn.estimators.estimator import SKCompat

# TensorFlow的高层封装TFLearn
learn = tf.contrib.learn

# 神经网络参数
HIDDEN_SIZE = 30  # LSTM隐藏节点个数
NUM_LAYERS = 2  # LSTM层数
TIMESTEPS = 10  # 循环神经网络截断长度
BATCH_SIZE = 32  # batch大小

# 数据参数
TRAINING_STEPS = 3000  # 训练轮数
TRAINING_EXAMPLES = 10000  # 训练数据个数
TESTING_EXAMPLES = 1000  # 测试数据个数
SAMPLE_GAP = 0.01  # 采样间隔


def generate_data(seq):
    # 序列的第i项和后面的TIMESTEPS-1项合在一起作为输入,第i+TIMESTEPS项作为输出
    X = []
    y = []
    
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值