A note to self, in case I forget.
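For reference, these are the standard LSTM cell equations the code below implements (note that each nn.Linear carries its own bias, so every gate effectively gets two bias terms):

i_t = sigmoid(U_i x_t + V_i h_{t-1})    # input gate
f_t = sigmoid(U_f x_t + V_f h_{t-1})    # forget gate
g_t = tanh(U_g x_t + V_g h_{t-1})       # candidate cell state
o_t = sigmoid(U_o x_t + V_o h_{t-1})    # output gate
c_t = f_t * c_{t-1} + i_t * g_t         # * is element-wise
h_t = o_t * tanh(c_t)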
import torch
import torch.nn as nn


class myLstm(nn.Module):
    # input_sz is the embedding dimension of each token (character)
    def __init__(self, input_sz, hidden_sz):
        super().__init__()
        self.input_size = input_sz
        self.hidden_size = hidden_sz
        # i_t (input gate)
        self.U_i = nn.Linear(input_sz, hidden_sz)
        self.V_i = nn.Linear(hidden_sz, hidden_sz)
        # f_t (forget gate)
        self.U_f = nn.Linear(input_sz, hidden_sz)
        self.V_f = nn.Linear(hidden_sz, hidden_sz)
        # g_t (candidate cell state)
        self.U_g = nn.Linear(input_sz, hidden_sz)
        self.V_g = nn.Linear(hidden_sz, hidden_sz)
        # o_t (output gate)
        self.U_o = nn.Linear(input_sz, hidden_sz)
        self.V_o = nn.Linear(hidden_sz, hidden_sz)
    def forward(self, x, init_states=None):
        # x.size() = (batch_size, seq_len, embedding_dim)
        x = x.float()
        bs, seq_sz, _ = x.size()
        hidden_seq = []
        # Initialize the hidden state h_0 and cell state c_0
        # (on the same device as the input)
        if init_states is None:
            h_t, c_t = (
                torch.zeros(bs, self.hidden_size, device=x.device),
                torch.zeros(bs, self.hidden_size, device=x.device),
            )
        else:
            h_t, c_t = init_states
        # Process the t-th token of every sequence in the batch
        for t in range(seq_sz):
            x_t = x[:, t, :]
            # Compute i_t
            i_t = torch.sigmoid(self.U_i(x_t) + self.V_i(h_t))
            # Compute f_t
            f_t = torch.sigmoid(self.U_f(x_t) + self.V_f(h_t))
            # Compute g_t
            g_t = torch.tanh(self.U_g(x_t) + self.V_g(h_t))
            # Compute o_t
            o_t = torch.sigmoid(self.U_o(x_t) + self.V_o(h_t))
            # All four gates take x_t and h_t as inputs, and each output
            # has shape (batch_size, hidden_size).
            # The c_t on the right-hand side is really c_{t-1};
            # this yields the current c_t
            c_t = f_t * c_t + i_t * g_t
            # Multiply o_t by tanh of the new c_t to get the current h_t
            h_t = o_t * torch.tanh(c_t)
            # Record h_t at every time step
            hidden_seq.append(h_t.unsqueeze(0))
        # hidden_seq is now a list of seq_len tensors, each of shape
        # (1, batch_size, hidden_size); concatenate them into
        # (seq_len, batch_size, hidden_size)
        hidden_seq = torch.cat(hidden_seq, dim=0)
        # Rearrange to (batch_size, seq_len, hidden_size)
        hidden_seq = hidden_seq.transpose(0, 1).contiguous()
        return hidden_seq, (h_t, c_t)
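Since forward returns both the full hidden sequence and the final (h_t, c_t), mirroring nn.LSTM's interface, one quick sanity check is that the last time step of hidden_seq equals the returned h_t. A minimal sketch, assuming the class above (the sizes are arbitrary):

check = myLstm(8, 16)
seq, (h, c) = check(torch.randn(4, 5, 8))
assert torch.equal(seq[:, -1, :], h)  # last step of the sequence is the final hidden state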
if __name__ == '__main__':
    # Quick test
    batch_size = 100
    hidden_size = 64
    embedding_dim = 32
    min_id = 1
    max_id = 3000
    model = myLstm(embedding_dim, hidden_size)
    # Random integers stand in for pre-embedded input vectors here;
    # forward() casts them to float
    inputs = torch.randint(min_id, max_id, (batch_size, 50, embedding_dim))
    print(model)
    print(model(inputs))
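As a further check on shapes, the output can be compared against the built-in nn.LSTM with batch_first=True (its weights are initialized independently, so only the shapes match, not the values). Note that nn.LSTM prepends a num_layers dimension to its final states, whereas myLstm above does not. Continuing inside the __main__ block:

    ref = nn.LSTM(embedding_dim, hidden_size, batch_first=True)
    ref_out, (ref_h, ref_c) = ref(inputs.float())
    print(ref_out.shape)  # torch.Size([100, 50, 64]), same as hidden_seq
    print(ref_h.shape)    # torch.Size([1, 100, 64]); myLstm returns (100, 64)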