# Imports
import torch
from echotorch.datasets.MackeyGlassDataset import MackeyGlassDataset
import echotorch.nn as etnn
import echotorch.utils
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
# Dataset parameters (forwarded to MackeyGlassDataset below)
train_sample_length = 5000  # first argument of the training MackeyGlassDataset
test_sample_length = 1000   # first argument of the test MackeyGlassDataset
n_train_samples = 1         # second argument of the training MackeyGlassDataset
n_test_samples = 1          # second argument of the test MackeyGlassDataset

# ESN hyper-parameters (forwarded to etnn.LiESN below)
spectral_radius = 0.9
leaky_rate = 1.0
input_dim = 1
n_hidden = 100              # passed to LiESN as hidden_dim

# Use CUDA? Flip the first assignment to True to opt in; the second line
# still falls back to CPU when torch reports no available CUDA device.
use_cuda = False
use_cuda = torch.cuda.is_available() if use_cuda else False
# Mackey-Glass datasets (tau=30) for training and testing
mackey_glass_train_dataset = MackeyGlassDataset(train_sample_length, n_train_samples, tau=30)
mackey_glass_test_dataset = MackeyGlassDataset(test_sample_length, n_test_samples, tau=30)

# Data loaders.
# num_workers=0 loads batches in the main process. This script has no
# `if __name__ == "__main__":` guard, so worker processes started with the
# "spawn" method (default on Windows and macOS) would re-import the script,
# re-run this top-level code and fail while bootstrapping. With shuffle=False
# the batch order is the same as with worker processes.
trainloader = DataLoader(mackey_glass_train_dataset, batch_size=1, shuffle=False, num_workers=0)
testloader = DataLoader(mackey_glass_test_dataset, batch_size=1, shuffle=False, num_workers=0)
# ESN cell: an echotorch LiESN with a 1-dimensional output.
# learning_algo='inv' selects echotorch's 'inv' training algorithm
# (presumably a closed-form solve from the accumulated xTx / xTy; confirm
# in the echotorch docs).
esn = etnn.LiESN(
    input_dim=input_dim,
    hidden_dim=n_hidden,
    output_dim=1,
    spectral_radius=spectral_radius,
    learning_algo='inv',
    leaky_rate=leaky_rate
)
if use_cuda:
    esn.cuda()
# end if
# For each training batch, call the ESN with both inputs and targets.
# Per the original comment, this accumulates xTx and xTy; the actual
# fit happens in esn.finalize() below.
for data in trainloader:
    # Inputs and outputs
    inputs, targets = data
    # Wrap as Variable (deprecated in modern PyTorch, kept for old versions)
    inputs, targets = Variable(inputs), Variable(targets)
    if use_cuda:
        inputs, targets = inputs.cuda(), targets.cuda()
    # Accumulate xTx and xTy
    esn(inputs, targets)
# end for

# Finish training
esn.finalize()
# Training-set data: take the first batch from a fresh iterator.
# The builtin next() is used because the `.next()` method is not part of
# the Python 3 iterator protocol and is missing on recent DataLoader iterators.
dataiter = iter(trainloader)
train_u, train_y = next(dataiter)
train_u, train_y = Variable(train_u), Variable(train_y)
if use_cuda:
    train_u, train_y = train_u.cuda(), train_y.cuda()

# Predict on the training set (no targets passed)
y_predicted = esn(train_u)

# Print training-set errors
print(u"Train MSE: {}".format(echotorch.utils.mse(y_predicted.data, train_y.data)))
print(u"Train NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, train_y.data)))
print(u"")
# Test-set data: take the first batch from a fresh iterator.
# The builtin next() is used because the `.next()` method is not part of
# the Python 3 iterator protocol and is missing on recent DataLoader iterators.
dataiter = iter(testloader)
test_u, test_y = next(dataiter)
test_u, test_y = Variable(test_u), Variable(test_y)
if use_cuda:
    test_u, test_y = test_u.cuda(), test_y.cuda()

# Predict on the test set (no targets passed)
y_predicted = esn(test_u)

# Print test-set errors
print(u"Test MSE: {}".format(echotorch.utils.mse(y_predicted.data, test_y.data)))
print(u"Test NRMSE: {}".format(echotorch.utils.nrmse(y_predicted.data, test_y.data)))
print(u"")