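# [TensorFlow深度学习入门]实战十·用RNN(LSTM)做时间序列预测(曲线拟合)
# (TensorFlow deep-learning primer, exercise 10: time-series prediction / curve fitting with an RNN (LSTM))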
% matplotlib inline
import os
os. environ[ "KMP_DUPLICATE_LIB_OK" ] = "TRUE"
import numpy as np
import matplotlib. pyplot as plt
import tensorflow as tf
# --- Hyperparameters --------------------------------------------------------
lr, training_iters = 0.002, 500   # Adam learning rate; the training loop runs training_iters + 1 steps
batch_size = 30                   # NOTE(review): not read by the visible code -- every step feeds all windows at once
n_inputs, n_steps = 1, 10         # features per time step; time steps per input window
n_hidden_units = 16               # LSTM hidden-state size
n_classes = 1                     # output width: one predicted value per window
def get_data(x, w, b, noise=0.01):
    """Return a noisy sine curve sampled at ``x``.

    Computes ``sin(w * x) + b`` and adds, element-wise, independent uniform
    noise drawn from ``[-noise, noise)``.

    Args:
        x: array-like of sample positions, of any shape.
        w: angular frequency of the sine.
        b: constant vertical offset.
        noise: half-width of the uniform noise (default 0.01).

    Returns:
        ndarray with the same shape as ``x``.
    """
    x = np.asarray(x)
    # rand(*x.shape) draws exactly the same numbers as rand(c, r) for a 2-D x,
    # so seeded runs are reproducible; it also accepts 1-D or scalar input.
    jitter = noise * (2 * np.random.rand(*x.shape) - 1)
    return np.sin(w * x) + b + jitter
# Sample the curve on [0, 3) in steps of 0.01 and generate the noisy sine target.
xs = np.arange(0, 3, 0.01).reshape(-1, 1)
ys = get_data(xs, 5, 0.5)

# Slide an 11-value window over ys: the first 10 values are the model inputs
# and the 11th is the next-step target. The range stops at len(xs) - 11, so
# the final possible window (ys[-11:]) is not included.
datas = np.array(
    [ys[start:start + 11] for start in range(len(xs) - 11)]
).reshape(-1, 11)
print(datas.shape)

# Show the raw generated curve.
plt.title("curve")
plt.plot(ys)
plt.show()
# Output: (289, 11)
# Graph inputs: each row of x is one flattened window of n_steps * n_inputs
# values; y holds the matching next-step target(s).
x = tf.placeholder(tf.float32, shape=[None, n_steps * n_inputs])
y = tf.placeholder(tf.float32, shape=[None, n_classes])

# Output projection from the LSTM hidden state to the prediction.
weights = dict(out=tf.Variable(tf.random_normal([n_hidden_units, n_classes])))
biases = dict(out=tf.Variable(tf.constant(0.1, shape=[1, n_classes])))
def RNN(X, weights, biases):
    """Build an LSTM regressor that predicts one value per input window.

    Args:
        X: tensor whose rows are flattened windows; it is reshaped to
            (batch, n_steps, n_inputs) before being fed to the LSTM.
        weights: dict whose "out" entry projects the last hidden state to
            the output (assumed (n_hidden_units, n_classes) -- confirm
            against the caller).
        biases: dict whose "out" entry is added to that projection
            (assumed broadcastable to (batch, n_classes)).

    Returns:
        Tensor of predictions, one row per input window.
    """
    # Restore the time axis: (batch, n_steps * n_inputs) -> (batch, n_steps, n_inputs).
    X = tf.reshape(X, [-1, n_steps, n_inputs])
    # BasicLSTMCell is deprecated; its deprecation notice names
    # LSTMCell(name='basic_lstm_cell') as the drop-in replacement.
    lstm_cell = tf.nn.rnn_cell.LSTMCell(
        n_hidden_units, forget_bias=1.0, state_is_tuple=True, name="basic_lstm_cell"
    )
    # Size the zero state from the runtime batch dimension rather than a fixed
    # 289, so the graph accepts any number of windows.
    batch = tf.shape(X)[0]
    init_state = lstm_cell.zero_state(batch, dtype=tf.float32)
    output, states = tf.nn.dynamic_rnn(
        lstm_cell, X, initial_state=init_state, time_major=False
    )
    print(output)
    # Use only the last time step's hidden state. The bias is added AFTER the
    # projection (adding it to the weight matrix is not a bias term).
    result = tf.matmul(output[:, -1, :], weights["out"]) + biases["out"]
    return result
# Build the prediction, the mean-squared-error loss, and the Adam optimiser step.
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.square(pred - y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)
init = tf.global_variables_initializer()

# Each row of `datas` is one window: every column except the last is a model
# input, and the last column is the value to predict. Slicing the whole array
# avoids repeating the window count (289) and column split as literals.
train_x = datas[:, :-1]
train_y = datas[:, -1:]
# The feed never changes between steps, so build it once.
train_feed = {x: train_x, y: train_y}

with tf.Session() as sess:
    sess.run(init)
    # Full-batch training: every step uses all windows.
    for t in range(training_iters + 1):
        sess.run(train_op, feed_dict=train_feed)
        if t % 10 == 0:
            # The loss is measured after this step's update.
            loss_val = sess.run(cost, feed_dict=train_feed)
            print(t, loss_val)
    # Predict on the training windows and plot the fitted curve.
    y_val = sess.run(pred, feed_dict={x: train_x}).reshape(-1, 1)
    plt.title("pre")
    plt.plot(y_val)
    plt.show()
# Captured notebook output:
# WARNING:tensorflow:From <ipython-input-2-130bdeb48069>:20: BasicLSTMCell.__init__ (from tensorflow.python.ops.rnn_cell_impl) is deprecated and will be removed in a future version.
# Instructions for updating:
# This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').
# Tensor("rnn/transpose_1:0", shape=(289, 10, 16), dtype=float32)
# 0 2.7168088
# 10 1.0216647
# 20 0.29450005
# 30 0.16755253
# ...
# 470 0.0010900635
# 480 0.001046965
# 490 0.001006315
# 500 0.0009679485