LSTM (long short-term memory network): its defining feature is three gates (a forget gate, an input gate, and an output gate) that control how the previous hidden state h_{t-1} and the long-term cell state c_{t-1} influence the current cell state c_t and hidden state h_t.
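Written out, the gate updates look like this (standard LSTM equations; sigma is the sigmoid and \odot is elementwise multiplication):

```latex
f_t = \sigma(x_t W_{xf} + h_{t-1} W_{hf} + b_f)            % forget gate
i_t = \sigma(x_t W_{xi} + h_{t-1} W_{hi} + b_i)            % input gate
o_t = \sigma(x_t W_{xo} + h_{t-1} W_{ho} + b_o)            % output gate
\tilde{c}_t = \tanh(x_t W_{xc} + h_{t-1} W_{hc} + b_c)     % candidate cell state
c_t = f_t \odot c_{t-1} + i_t \odot \tilde{c}_t            % keep old memory, add new
h_t = o_t \odot \tanh(c_t)                                 % expose part of the cell state
```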
import torch

class LSTM(torch.nn.Module):
    def __init__(self, num_vocab, num_hidden, num_classes):
        super().__init__()
        # One (W_x, W_h, b) triple per gate, plus one for the candidate cell state
        self.Wxf, self.Whf, self.bf = self.param_def(num_vocab, num_hidden)  # forget gate
        self.Wxi, self.Whi, self.bi = self.param_def(num_vocab, num_hidden)  # input gate
        self.Wxc, self.Whc, self.bc = self.param_def(num_vocab, num_hidden)  # candidate cell state
        self.Wxo, self.Who, self.bo = self.param_def(num_vocab, num_hidden)  # output gate
        # Fully connected output layer
        self.fc = torch.nn.Linear(num_hidden, num_classes)

    def param_def(self, num_vocab, num_hidden):
        # Registered as parameters; scaled down so the sigmoids don't saturate at the start
        Wx = torch.nn.Parameter(torch.randn(num_vocab, num_hidden) * 0.01)
        Wh = torch.nn.Parameter(torch.randn(num_hidden, num_hidden) * 0.01)
        b = torch.nn.Parameter(torch.zeros(num_hidden))
        return Wx, Wh, b

    def forward(self, Xs, C0, H0):
        # Xs is iterated over its first axis, so it must be (seq_length, batch_size, num_vocab)
        Hs = [H0]
        Cs = [C0]
        for X in Xs:
            F = torch.sigmoid(X @ self.Wxf + Hs[-1] @ self.Whf + self.bf)   # forget gate
            I = torch.sigmoid(X @ self.Wxi + Hs[-1] @ self.Whi + self.bi)   # input gate
            O = torch.sigmoid(X @ self.Wxo + Hs[-1] @ self.Who + self.bo)   # output gate
            C_tilde = torch.tanh(X @ self.Wxc + Hs[-1] @ self.Whc + self.bc)  # candidate cell state
            C = F * Cs[-1] + I * C_tilde  # new cell state: keep part of the old, add new
            H = O * torch.tanh(C)         # new hidden state
            Cs.append(C)
            Hs.append(H)
        output = self.fc(H)  # prediction from the final hidden state only
        return Cs, Hs, output
tips:
1. This code is still a work in progress, and I haven't fully worked out how to call the LSTM block above yet. Note that because `forward` iterates over the first axis of Xs as time steps, the input should be shaped (seq_length, batch_size, num_vocab); with one-hot encoded tokens, num_vocab is usually also equal to num_classes (the dimension of each target value y).
2. The hidden_size parameter of an LSTM (num_hidden here) sets the number of units in the hidden layer, i.e. the dimension of the hidden state. This choice has a large impact on the network's performance, because it determines how much information the network can store and process.
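For comparison, PyTorch's built-in `torch.nn.LSTM` illustrates both tips: it defaults to the same (seq_length, batch_size, feature) input convention, and `hidden_size` sets the hidden-state dimension. The sizes below are made up for illustration:

```python
import torch

num_vocab, hidden_size, seq_length, batch_size = 28, 64, 5, 4

# By default (batch_first=False), input is (seq_length, batch_size, input_size)
lstm = torch.nn.LSTM(input_size=num_vocab, hidden_size=hidden_size)

Xs = torch.randn(seq_length, batch_size, num_vocab)
h0 = torch.zeros(1, batch_size, hidden_size)  # (num_layers, batch_size, hidden_size)
c0 = torch.zeros(1, batch_size, hidden_size)

outputs, (hn, cn) = lstm(Xs, (h0, c0))
print(outputs.shape)  # torch.Size([5, 4, 64]) -- one hidden state per time step
print(hn.shape)       # torch.Size([1, 4, 64]) -- final hidden state
```

A final `Linear(hidden_size, num_classes)` layer would then map each hidden state to class scores, just like `self.fc` in the handwritten version.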
I'll keep improving this over time!