FORCE learning

参考文献

  • Generating Coherent Patterns of Activity from Chaotic Neural Networks
    在这里插入图片描述

  • Collective dynamics of rate neurons for supervised learning in a reservoir computing system在这里插入图片描述

工作原理

在这里插入图片描述
将递归最小二乘法(RLS)用于储备池的输出权重在线更新

import numpy as np
import random
import os
import time
import pandas as pd
import matplotlib.pyplot as plt
%matplotlib qt5
from tqdm.notebook import tqdm

def set_seed(seed=None):
    pass
    
system_name = 'lorentz'
Y = np.loadtxt('dataset/'+system_name+'.txt', delimiter=',').T
T = Y.shape[1]

inSize = Y.shape[0]
outSize = inSize 
resSize = 400 
a = 0.9
K = 1
reg = 1e-6
input_scaling = 0.5

train_time = 1000
seed = 44  
set_seed(seed)
Win = (np.random.rand(resSize,1+inSize)-0.5) * input_scaling
W = np.random.rand(resSize,resSize)-0.5
rhoW = np.sqrt(max(abs(np.linalg.eig(W@W.T)[0]))) # maximal eigenvalue
W = W/rhoW*(K-1+a)/a 
Wout = np.random.rand(outSize, 1+inSize+resSize) - 0.5


x = np.zeros((resSize,1))
S = np.zeros((1+inSize+resSize,T))
Z = np.zeros([outSize, T])
Error = np.zeros([inSize, T])
Time = list(range(T))
P = np.eye(S.shape[0])/reg

f, ax =plt.subplots(nrows=1,ncols=2, figsize=(20,5))
ax[0].set_xlabel('Time')
ax[0].set_ylabel('f(t)')
line1, line2, line3 = None, None, None
plt.grid(True)
plt.ion()

for t in tqdm(range(1,T)):
    if t < train_time:
        u = Y[:,t-1:t] + np.random.randn(inSize,1)*0.001
    else:
        u = z
    x = (1-a) * x + a * np.tanh(Win @ np.vstack((1, u))) + W @ x)
    s =  np.vstack((1,u,x))
    z = Wout @ s
    y = Y[:,t:t+1]
    dy = z - y
    P -= P @ s @ s.T @ P/(1+s.T @ P @ s)
    dWout = dy @ (P @ s).T    
    
    if t < train_time:
        Wout -= dWout
    
    S[:,t:t+1] = s
    Z[:,t:t+1] = z
    Error[:,t:t+1] = np.abs(dy)
    
    """
    画图
	"""
    if line1 is None:
        line1 = ax[0].plot(Time[1:t],Z[0,1:t],'-g',marker='*', label='output')[0]
        line2 = ax[0].plot(Time[1:t],Y[0,1:t],'-r',marker='*', label='target')[0]
        line3 = ax[1].plot(Time[1:t],Error[0,1:t],'-k',marker='.', label='error')[0]
        ax[0].legend(loc='upper left')
        ax[1].legend(loc='upper left')
    
    line1.set_xdata(Time[1:t])
    line1.set_ydata(Z[0,1:t])
    line3.set_xdata(Time[1:t])
    line3.set_ydata(Error[0,1:t])
    line2.set_xdata(Time[1:t])
    line2.set_ydata(Y[0,1:t])
    
    ax[0].set_xlim([t-300,t+1])
    ax[0].set_ylim([0,1])
    ax[1].set_xlim([t-300,t+1])
    ax[1].set_yscale("log")
    ax[1].set_ylim([1e-6,10])
    plt.pause(0.001)
  • 训练阶段: t < train_time = 1000 t < \text{train\_time} = 1000 t<train_time=1000
    在这里插入图片描述
  • 训练结束后运行, t > train_time = 1000 t > \text{train\_time} = 1000 t>train_time=1000

在这里插入图片描述
在这里插入图片描述

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值