循环神经网络小笔记1-1

#构建循环人工神经网络
import numpy as np
import matplotlib.pyplot as plt
import neurolab as nl

# 准备数据集
# 用np生成一些序列数据,这个序列数据有四段
def get_data(num_points):
    '''建立波形数据集,这个数据集含有四段,每一段的数据点数为num_points'''
    wave1 = 0.49 * np.sin(np.arange(0, num_points))#生成一维数组
    wave2 = 3.62 * np.sin(np.arange(0, num_points))
    wave3 = 1.2 * np.sin(np.arange(0, num_points))
    wave4 = 4.6 * np.sin(np.arange(0, num_points))
    # 每一段数据的幅度不同 分别是1,2,3.1,0.9
    amp_1 = np.ones(num_points)
    amp_2 = 2 + np.zeros(num_points)
    amp_3 = 3.1 + np.zeros(num_points)
    amp_4 = 0.9 + np.zeros(num_points)
    # 4行num_points列,转变为:4*num_points行,一列,即为整个序列
    wave = np.array([wave1, wave2, wave3, wave4]).reshape(num_points * 4, 1)
    amp = np.array([amp_1, amp_2, amp_3, amp_4]).reshape(num_points * 4, 1)

    return wave, amp


def visualize_output(nn, num_points_test):
    wave, amp = get_data(num_points_test)
    output = nn.sim(wave)
    plt.plot(amp.reshape(num_points_test * 4))
    plt.plot(output.reshape(num_points_test * 4))

if __name__ == '__main__':
    #训练数据的点数50
    num_points = 50
    wave, amp = get_data(num_points)
    # plt.figure()
    # plt.plot(wave)
    # plt.plot(amp)
    # plt.xlabel('Dimension 1')
    # plt.ylabel('Dimension 2')
    # plt.title('Input data')
    # plt.show()

    #  Create network with 2 layers
    nn = nl.net.newelm([[-3, 3]], [9, 1], [nl.trans.TanSig(), nl.trans.PureLin()])
    # Set initialized functions and init
    nn.layers[0].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
    nn.layers[1].initf = nl.init.InitRand([-0.1, 0.1], 'wb')
    nn.init()
    # Train network
    error_progress = nn.train(wave, amp, epochs=1200, show=100, goal=0.01)
    # Simulate network
    # 假设以训练集所用的wave为样本,那么看看得到的结果和期望期的差异
    output = nn.sim(wave)


    #子图,误差改进过程
    plt.subplot(221)
    plt.plot(error_progress)
    plt.xlabel('Number of epochs')
    plt.ylabel('Error')
    #子图,amp为True,output为网络仿真得到的Predicted
    plt.subplot(222)
    plt.plot(amp.reshape(num_points * 4)) #True
    plt.plot(output.reshape(num_points * 4)) #Predicted
    plt.legend(['Original', 'Predicted'])
    #子图,误差改进过程
    plt.subplot(223)
    plt.plot(error_progress)
    plt.xlabel('Number of epochs')
    plt.ylabel('Error')
    #子图,amp为True,output为网络仿真得到的Predicted
    plt.subplot(224)
    plt.plot(amp.reshape(num_points * 4)) #True
    plt.plot(output.reshape(num_points * 4)) #Predicted
    plt.legend(['Original', 'Predicted'])


    # 生成新的数据集,再进行测试
    #可以发现,对于新产生的序列数据,该模型也能够大体预测出来。
    plt.figure()
    plt.subplot(211)
    #测试数据的点数83
    visualize_output(nn, 100)
    plt.xlim([0, 400])
    plt.subplot(212)
    #测试数据的点数36
    visualize_output(nn, 25)
    plt.xlim([0, 200])
    plt.show()




















训练波形数据wave和amp标签的关系

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值