循环神经网络十分擅长处理时间相关的数据,下面我们就通过输入sin函数,输出cos函数来实际应用
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import optim
import numpy as np
from matplotlib import pyplot as plt
import matplotlib.animation
import math, random
# Hyperparameters
time_step = 10   # number of time steps in each RNN input sequence
input_size = 1   # number of input features per time step
h_size = 64      # number of hidden units in the RNN
epochs = 300     # number of training iterations
h_state = None   # hidden state of the RNN; starts out as None

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Sample data generated with NumPy: 256 evenly spaced points on [0, 2*pi],
# with sin as the network input and cos as the target.
steps = np.linspace(0, 2 * np.pi, num=256, dtype=np.float32)
x_np, y_np = np.sin(steps), np.cos(steps)
# Preview the data on figure 1: the target curve (cos) first, then the input (sin).
plt.figure(1)
plt.suptitle('Sin and Cos', fontsize='18')  # figure title and its font size
for curve, line_fmt, legend_name in (
    (y_np, 'r-', 'target(cos)'),
    (x_np, 'y-', 'input(sin)'),
):
    plt.plot(steps, curve, line_fmt, label=legend_name)
# 'best' lets matplotlib pick the legend location that overlaps the data least.
plt.legend(loc='best')
plt.show()
class RNN(nn.Module):
    """Single-layer ``nn.RNN`` followed by a linear read-out on every time step.

    Args:
        in_size: Number of input features per time step. ``None`` (default)
            falls back to the module-level ``input_size`` hyperparameter.
        hidden_size: Number of hidden units. ``None`` (default) falls back to
            the module-level ``h_size`` hyperparameter.
    """

    def __init__(self, in_size=None, hidden_size=None):
        super().__init__()
        # Resolve defaults lazily so that ``RNN()`` keeps using the script's
        # hyperparameters, while explicit sizes make the class reusable.
        if in_size is None:
            in_size = input_size
        if hidden_size is None:
            hidden_size = h_size
        self.rnn = nn.RNN(
            input_size=in_size,
            hidden_size=hidden_size,
            num_layers=1,
            batch_first=True,  # tensors are laid out as (batch, seq, feature)
        )
        # Maps each hidden vector to a single output value.
        self.out = nn.Linear(hidden_size, 1)

    def forward(self, x, h_state):
        """Run the RNN over ``x`` and project every time step to one value.

        Args:
            x: Input of shape (batch, seq_len, in_size).
            h_state: Initial hidden state, or ``None`` to let ``nn.RNN``
                start from its default initial state.

        Returns:
            A tuple ``(outputs, h_state)`` where ``outputs`` has shape
            (batch, seq_len, 1) and ``h_state`` is the final hidden state
            returned by ``nn.RNN``.
        """
        r_out, h_state = self.rnn(x, h_state)
        # nn.Linear acts on the last dimension, so a single call covers all
        # time steps; no per-step Python loop and torch.stack are needed.
        return self.out(r_out), h_state
# Build the model, optimizer and loss, then train the RNN to map sin -> cos.
rnn = RNN().to(device)
optimizer = torch.optim.Adam(rnn.parameters())
criterion = nn.MSELoss()

rnn.train()
plt.figure(2)
# Legend labels are attached only on the first refresh; labelling every
# plotted segment would add duplicate legend entries on each refresh.
legend_drawn = False
for step in range(epochs):
    # Each iteration trains on the next window [step*pi, (step+1)*pi].
    start, end = step * np.pi, (step + 1) * np.pi
    steps = np.linspace(start, end, time_step, dtype=np.float32)
    x_np = np.sin(steps)
    y_np = np.cos(steps)
    # Add batch and feature axes: shape (batch, time_step, input_size).
    x = torch.from_numpy(x_np[np.newaxis, :, np.newaxis])
    y = torch.from_numpy(y_np[np.newaxis, :, np.newaxis])
    x = x.to(device)

    prediction, h_state = rnn(x, h_state)
    # Important: keep the hidden-state values for the next window but detach
    # them from this iteration's graph, so backward() never reaches into
    # earlier iterations.
    h_state = h_state.detach()

    # Compute the loss on the model's device instead of copying the
    # prediction back to the CPU on every step.
    loss = criterion(prediction, y.to(device))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Every 20 iterations, print the loss and draw the current fit.
    if (step + 1) % 20 == 0:
        # '.4f' gives four decimals; the previous '4f' only set a field width.
        print("epochs: {},Loss:{:.4f}".format(step + 1, loss.item()))
        target_label = None if legend_drawn else 'cos'
        prediction_label = None if legend_drawn else 'prediction'
        plt.plot(steps, y_np, 'r-', label=target_label)
        plt.plot(
            steps,
            prediction.detach().cpu().numpy().flatten(),
            'b-',
            label=prediction_label,
        )
        if not legend_drawn:
            plt.legend(loc='best')
            legend_drawn = True
        plt.draw()
        plt.pause(0.01)