ESN学习笔记——echotorch(3)mackey_glass_esn

加载

import torch
from echotorch.datasets.MackeyGlassDataset import MackeyGlassDataset
import echotorch.nn as etnn
import echotorch.utils
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader

数据集参数

train_sample_length = 5000
test_sample_length = 1000
n_train_samples = 1
n_test_samples = 1
spectral_radius = 0.9
leaky_rate = 1.0
input_dim = 1
n_hidden = 100

Use CUDA?

use_cuda = False
use_cuda = torch.cuda.is_available() if use_cuda else False

Mackey glass 数据集

mackey_glass_train_dataset = MackeyGlassDataset(train_sample_length, n_train_samples, tau=30)
mackey_glass_test_dataset = MackeyGlassDataset(test_sample_length, n_test_samples, tau=30)

数据加载器

trainloader = DataLoader(mackey_glass_train_dataset, batch_size=1, shuffle=False, num_workers=2)
testloader = DataLoader(mackey_glass_test_dataset, batch_size=1, shuffle=False, num_workers=2)

ESN 细胞

esn = etnn.LiESN(
input_dim=input_dim,
hidden_dim=n_hidden,
output_dim=1,
spectral_radius=spectral_radius,
learning_algo=‘inv’,
leaky_rate=leaky_rate
)
if use_cuda:
esn.cuda()
#end if

对每一批

for data in trainloader:
#输入和输出
inputs, targets = data

#转为变量variable
inputs, targets = Variable(inputs), Variable(targets)
if use_cuda: inputs, targets = inputs.cuda(), targets.cuda()

#积累 xTx and xTy
esn(inputs, targets)
#end for

完成训练

esn.finalize()

训练集数据

dataiter = iter(trainloader)
train_u, train_y = dataiter.next()
train_u, train_y = Variable(train_u), Variable(train_y)

if use_cuda: train_u, train_y = train_u.cuda(), train_y.cuda()

在训练集上预测

y_predicted = esn(train_u)

打印训练集误差

print(u"Train MSE: {}".format(echotorch.utils.mse(y_predicted.data, train_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, train_y.data)))
print(u"")

测试集数据

dataiter = iter(testloader)
test_u, test_y = dataiter.next()
test_u, test_y = Variable(test_u), Variable(test_y)

if use_cuda: test_u, test_y = test_u.cuda(), test_y.cuda()

测试集上预测

y_predicted = esn(test_u)

打印预测误差

print(u"Test MSE: {}".format(echotorch.utils.mse(y_predicted.data, test_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, test_y.data)))
print(u"")

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值