Learning TensorFlow, Chapter 9, Section 9.2.2: Fitting an Echo Signal Sequence with an RNN (echo simulation)

1. Define the parameters and generate sample data

import numpy as np
import tensorflow.compat.v1 as tf
import matplotlib.pyplot as plt

tf.disable_v2_behavior()   # run this TF1.x-style graph code (placeholders, Session) under TensorFlow 2.x

num_epochs=5
total_series_length=50000
truncated_backprop_length=15   # length of each truncated-BPTT training window
state_size=4
num_classes=2
echo_step=3                    # the label sequence is the input delayed by 3 steps
batch_size=5
num_batches=total_series_length//batch_size//truncated_backprop_length
def generateData():
    x=np.array(np.random.choice(2,total_series_length,p=[0.5,0.5])) # draw total_series_length values from {0,1}, each with probability 0.5
    y=np.roll(x,echo_step) # circular right shift by echo_step, e.g. [1,1,1,1,0,0,0] becomes [0,0,0,1,1,1,1]
    y[0:echo_step]=0
    x=x.reshape((batch_size,-1))  # reshape to (5,10000)
    y=y.reshape((batch_size,-1))
    return(x,y)
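A quick sanity check of generateData (a minimal sketch; the shapes follow from the parameters above):

x,y=generateData()
print(x.shape,y.shape)   # (5, 10000) (5, 10000)
# y is x delayed by echo_step positions along the original flat sequence, e.g.
# np.roll([1,1,1,1,0,0,0],3) gives [0,0,0,1,1,1,1]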

2. Define placeholders and prepare the input data

batchX_placeholder=tf.placeholder(tf.float32,[batch_size,truncated_backprop_length])
batchY_placeholder=tf.placeholder(tf.int32,[batch_size,truncated_backprop_length])
init_state=tf.placeholder(tf.float32,[batch_size,state_size])

#split batchX_placeholder along axis 1
inputs_series=tf.unstack(batchX_placeholder,axis=1)
#into a list of truncated_backprop_length tensors, one per time step
labels_series=tf.unstack(batchY_placeholder,axis=1)
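A quick way to see what tf.unstack produced (a minimal sketch, run after the lines above): the placeholder of shape [batch_size, truncated_backprop_length] becomes a Python list with one tensor per time step.

# inputs_series is a list of truncated_backprop_length tensors, each of shape [batch_size]
print(len(inputs_series))             # 15
print(inputs_series[0].get_shape())   # (5,)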

3. Define the network structure

current_state=init_state
predictions_series=[]
losses=[]
#loop over the sequence, feeding one time step at a time
for current_input,labels in zip(inputs_series,labels_series):
    current_input=tf.reshape(current_input,[batch_size,1])
    #concatenate the current input with the previous state
    input_and_state_concatenated=tf.concat([current_input,current_state],1)
    #next_state=tf.contrib.layers.fully_connected(input_and_state_concatenated,state_size,activation_fn=tf.tanh)
    next_state=tf.keras.layers.Dense(state_size,activation='tanh')(input_and_state_concatenated)
    current_state=next_state
    #logits=tf.contrib.layers.fully_connected(next_state,num_classes,activation_fn=None)
    logits=tf.keras.layers.Dense(num_classes,activation=None)(next_state)
    loss=tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits=logits)
    losses.append(loss)
    predictions=tf.nn.softmax(logits)
    predictions_series.append(predictions)
    
total_loss=tf.reduce_mean(losses)
train_step=tf.train.AdagradOptimizer(0.3).minimize(total_loss)
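Note that creating a new tf.keras.layers.Dense inside the loop gives every time step its own weights. If you want all time steps to share one set of weights, as in a true RNN cell, one possible variant (a sketch, not the book's original code) is to build the two layers once before the loop and call them repeatedly:

# build the layers once so every time step reuses the same weights
state_layer=tf.keras.layers.Dense(state_size,activation='tanh')
logits_layer=tf.keras.layers.Dense(num_classes,activation=None)

current_state=init_state
predictions_series=[]
losses=[]
for current_input,labels in zip(inputs_series,labels_series):
    current_input=tf.reshape(current_input,[batch_size,1])
    input_and_state_concatenated=tf.concat([current_input,current_state],1)
    current_state=state_layer(input_and_state_concatenated)   # same weights at every step
    logits=logits_layer(current_state)
    losses.append(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels,logits=logits))
    predictions_series.append(tf.nn.softmax(logits))

total_loss and train_step would then be defined exactly as above.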

The plot function is defined as follows:

def plot(loss_list,predictions_series,batchX,batchY):
    plt.subplot(2,3,1)
    plt.cla()
    plt.plot(loss_list)
    
    for batch_series_idx in range(batch_size):
        one_hot_output_series=np.array(predictions_series)[:,batch_series_idx,:]
        single_output_series=np.array([(1 if out[0]<0.5 else 0) for out in one_hot_output_series])
        plt.subplot(2,3,batch_series_idx+2)
        plt.cla()
        plt.axis([0,truncated_backprop_length,0,2])
        left_offset=range(truncated_backprop_length)
        left_offset2=range(echo_step,truncated_backprop_length+echo_step)
        label1="pass values"
        label2="True echo values"
        label3="Predictions"
        plt.plot(left_offset2,batchX[batch_series_idx,:]*0.2+1.5,"o--b",label=label1)
        plt.plot(left_offset,batchY[batch_series_idx,:]*0.2+0.8,"x--b",label=label2)
        plt.plot(left_offset,single_output_series*0.2+0.1,"o--y",label=label3)
    plt.legend(loc='best')
    plt.draw()
    plt.pause(0.0001)
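Inside plot, the expression 1 if out[0]<0.5 else 0 converts each two-class softmax output into a predicted label; because the two probabilities sum to 1, this is (ties aside) the same as taking the argmax:

single_output_series=np.array([np.argmax(out) for out in one_hot_output_series])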

4. Build a session and train the model

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    plt.ion()
    plt.figure()
    plt.show()
    loss_list=[]
    
    for epoch_idx in range(num_epochs):
        x,y=generateData()
        _current_state=np.zeros((batch_size,state_size))
        print("New data,epoch",epoch_idx)
        for batch_idx in range(num_batches):
            start_idx=batch_idx*truncated_backprop_length
            end_idx=start_idx+truncated_backprop_length
            batchX=x[:,start_idx:end_idx]
            batchY=y[:,start_idx:end_idx]
            
            _total_loss,_train_step,_current_state,_predictions_series=sess.run(
                [total_loss,train_step,current_state,predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })
            loss_list.append(_total_loss)

5. Test the model and visualize the results

The code below continues the inner training loop from step 4: every 100 batches it prints the current loss and plots the predictions.

            if batch_idx%100==0:
                print("Step",batch_idx,"Loss",_total_loss)
                plot(loss_list,_predictions_series,batchX,batchY)
plt.ioff()
plt.show()

The code above was run in Spyder (Python 3.10) under TensorFlow 2.x.

Two statements needed changing:

#next_state=tf.contrib.layers.fully_connected(input_and_state_concatenated,state_size,activation_fn=tf.tanh)
next_state=tf.keras.layers.Dense(state_size,activation='tanh')(input_and_state_concatenated)

#logits=tf.contrib.layers.fully_connected(next_state,num_classes,activation_fn=None)
logits=tf.keras.layers.Dense(num_classes,activation=None)(next_state)

TensorFlow 2.0 no longer includes tf.contrib, so the original tf.contrib.layers.fully_connected calls (kept above as comments) were replaced with tf.keras.layers.Dense.
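If you prefer to stay entirely within the compat.v1 graph API, tf.compat.v1.layers.dense also works and mirrors the old fully_connected call more closely; a sketch, using the import tensorflow.compat.v1 as tf alias from this script:

# tf.layers.dense here resolves to tf.compat.v1.layers.dense
next_state=tf.layers.dense(input_and_state_concatenated,state_size,activation=tf.tanh)
logits=tf.layers.dense(next_state,num_classes,activation=None)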

A second pitfall is the unpacking of the sess.run results inside the inner training loop:

            _total_loss,_train_step,_current_state,_predictions_series=sess.run(
                [total_loss,train_step,current_state,predictions_series],
                feed_dict={
                    batchX_placeholder:batchX,
                    batchY_placeholder:batchY,
                    init_state:_current_state
                })

Here I had originally dropped the leading underscore and written current_state on the left-hand side. That rebinds the name current_state from the graph tensor to the NumPy array returned by sess.run, so on the next iteration the fetch list no longer contains a tensor and the run fails with:

raise TypeError(f'Argument `fetch` = {fetch} has invalid type '

TypeError: Argument `fetch` = [[ 0.3233179   0.26914307 -0.38974735  0.01270094]
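For clarity, the broken and corrected assignments side by side (the underscore-prefixed names are ordinary Python variables holding NumPy results; current_state without the underscore is the graph tensor being fetched):

# broken: rebinds the tensor name current_state to a NumPy array,
# so the next sess.run fetch raises the TypeError above
#_total_loss,_train_step,current_state,_predictions_series=sess.run(...)
# correct: store the result in the underscore-prefixed variable
#_total_loss,_train_step,_current_state,_predictions_series=sess.run(...)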
