我想自己实现 LSTM 的具体网络结构，但在网上找了很久，看到的实现基本都是直接调用 nn.LSTM 模块。
把直接调用 nn.LSTM 模块改成下面的自定义 LSTM 之后，训练速度慢了很多，不知道是什么原因。
import torch
import torch.nn as nn
import math
class LSTM(nn.Module):
    """Single-layer, batch-first LSTM written out gate by gate.

    For each time step t (``@`` is matrix multiplication, ``*`` is
    element-wise multiplication)::

        i_t = sigmoid(x_t @ U_i + h_{t-1} @ V_i + b_i)   # input gate
        f_t = sigmoid(x_t @ U_f + h_{t-1} @ V_f + b_f)   # forget gate
        g_t = tanh(x_t @ U_c + h_{t-1} @ V_c + b_c)      # candidate cell state
        o_t = sigmoid(x_t @ U_o + h_{t-1} @ V_o + b_o)   # output gate
        c_t = f_t * c_{t-1} + i_t * g_t
        h_t = o_t * tanh(c_t)

    Why this is slower than ``nn.LSTM``: ``nn.LSTM`` hands the whole sequence
    to a fused native/cuDNN kernel. This module instead runs a Python loop
    over time steps and launches small kernels at every step. To narrow the
    gap, ``forward`` stacks the four gates' weights. The input projection is
    then computed for all time steps in a single matmul, and each step needs
    only one hidden-state matmul instead of eight separate ones. The
    dependence of step t on ``h_{t-1}`` still forces a sequential loop.

    Args:
        input_size: feature size of each input time step.
        hidden_size: size of the hidden state ``h`` and cell state ``c``.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        # Input gate i_t
        self.U_i = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_i = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_i = nn.Parameter(torch.empty(hidden_size))
        # Forget gate f_t
        self.U_f = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_f = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_f = nn.Parameter(torch.empty(hidden_size))
        # Candidate cell state (g_t in forward)
        self.U_c = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_c = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_c = nn.Parameter(torch.empty(hidden_size))
        # Output gate o_t
        self.U_o = nn.Parameter(torch.empty(input_size, hidden_size))
        self.V_o = nn.Parameter(torch.empty(hidden_size, hidden_size))
        self.b_o = nn.Parameter(torch.empty(hidden_size))
        self.init_weights()

    def init_weights(self):
        """Fill every parameter uniformly from [-1/sqrt(hidden_size), 1/sqrt(hidden_size)]."""
        stdv = 1.0 / math.sqrt(self.hidden_size)
        # In-place init must not be recorded by autograd.
        with torch.no_grad():
            for weight in self.parameters():
                weight.uniform_(-stdv, stdv)

    def forward(self, lstm_input, init_states=None):
        """Run the LSTM over a batch-first input sequence.

        Args:
            lstm_input: tensor of shape ``(batch, seq_len, input_size)``.
            init_states: optional ``(h_0, c_0)`` pair, each of shape
                ``(batch, hidden_size)``. Both default to zeros.

        Returns:
            ``(hidden_seq, (h_t, c_t))``. ``hidden_seq`` has shape
            ``(batch, seq_len, hidden_size)`` and holds ``h`` at every step.
            ``h_t`` / ``c_t`` are the states after the last step.
        """
        batch_size, seq_len = lstm_input.size(0), lstm_input.size(1)
        hidden_size = self.hidden_size
        if init_states is None:
            # new_zeros follows the input's device *and* dtype, so the module
            # also works after .double() / .half(). torch.zeros(...).to(device)
            # always produced float32.
            h_t = lstm_input.new_zeros(batch_size, hidden_size)
            c_t = lstm_input.new_zeros(batch_size, hidden_size)
        else:
            h_t, c_t = init_states

        # Stack the per-gate weights in (i, f, g, o) order. torch.cat is
        # differentiable, so gradients still reach each original Parameter.
        U = torch.cat([self.U_i, self.U_f, self.U_c, self.U_o], dim=1)
        V = torch.cat([self.V_i, self.V_f, self.V_c, self.V_o], dim=1)
        b = torch.cat([self.b_i, self.b_f, self.b_c, self.b_o])

        # The input term does not depend on h, so compute it for every time
        # step in one matmul: shape (batch, seq_len, 4 * hidden_size).
        x_gates = lstm_input @ U + b

        hidden_seq = []
        for t in range(seq_len):
            # One matmul per step for the recurrent term, then split the
            # result back into the four gates.
            gates = x_gates[:, t, :] + h_t @ V
            i_t, f_t, g_t, o_t = gates.chunk(4, dim=1)
            i_t = torch.sigmoid(i_t)
            f_t = torch.sigmoid(f_t)
            g_t = torch.tanh(g_t)
            o_t = torch.sigmoid(o_t)
            c_t = f_t * c_t + i_t * g_t
            h_t = o_t * torch.tanh(c_t)
            hidden_seq.append(h_t)

        if not hidden_seq:
            # Empty sequence: return an empty output instead of failing in
            # torch.cat/stack on an empty list.
            return x_gates.new_zeros(batch_size, 0, hidden_size), (h_t, c_t)
        # Stacking along dim 1 yields (batch, seq_len, hidden_size) directly,
        # already contiguous, with no transpose needed.
        return torch.stack(hidden_seq, dim=1), (h_t, c_t)
如果代码有什么问题，欢迎指出。