1. How the MultiInputLSTMCell Works
The MultiInputLSTMCell is the character-level recurrent cell used in Lattice-LSTM-style models: in addition to the usual character input, it can absorb the cell states of all lexicon words that end at the current character. There are two cases, as implemented in the code in section 2:

(1) When the current character is not the end of any lexicon word, no word cells are injected and the cell reduces to an LSTM with coupled input/forget gates:

$$i_j = \sigma(x_j W_{ih}^{(i)} + h_{j-1} W_{hh}^{(i)} + b^{(i)})$$
$$o_j = \sigma(x_j W_{ih}^{(o)} + h_{j-1} W_{hh}^{(o)} + b^{(o)})$$
$$g_j = \tanh(x_j W_{ih}^{(g)} + h_{j-1} W_{hh}^{(g)} + b^{(g)})$$
$$c_j = (1 - i_j) \odot c_{j-1} + i_j \odot g_j, \qquad h_j = o_j \odot \tanh(c_j)$$

(2) When the current character is the end of one or more lexicon words, each such word $b \in \{1, \dots, B\}$ contributes a cell state $c^w_{b,j}$, and an extra gate decides how much of each word cell to keep:

$$\alpha_{b,j} = \sigma(x_j W^{\alpha}_{ih} + c^w_{b,j} W^{\alpha}_{hh} + b^{\alpha})$$
$$[\hat{\alpha}_{0,j}, \hat{\alpha}_{1,j}, \dots, \hat{\alpha}_{B,j}] = \operatorname{softmax}\big([i_j, \alpha_{1,j}, \dots, \alpha_{B,j}]\big)$$
$$c_j = \hat{\alpha}_{0,j} \odot g_j + \sum_{b=1}^{B} \hat{\alpha}_{b,j} \odot c^w_{b,j}, \qquad h_j = o_j \odot \tanh(c_j)$$

where $x_j$ is the embedding of the $j$-th character, $h_{j-1}$ and $c_{j-1}$ are the previous hidden and cell states, $i_j$, $o_j$, $g_j$ are the input gate, output gate, and candidate cell state, $\sigma$ is the sigmoid function, $\odot$ is element-wise multiplication, and the softmax is taken element-wise along the hidden dimension over the $B + 1$ gates.
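To make the fusion in case (2) concrete, here is a minimal toy sketch of the element-wise softmax mix, under hypothetical values (B = 2 word cells, hidden size 4; all tensors random and purely illustrative):

import torch

torch.manual_seed(0)
i = torch.sigmoid(torch.randn(1, 4))       # character input gate i_j
g = torch.tanh(torch.randn(1, 4))          # candidate cell state g_j
word_cells = torch.randn(2, 4)             # word cell states c^w_{b,j}
alpha = torch.sigmoid(torch.randn(2, 4))   # word gates alpha_{b,j}
# Element-wise softmax over the B + 1 = 3 gates, per hidden dimension.
weights = torch.softmax(torch.cat([i, alpha], 0), dim=0)
c = (weights * torch.cat([g, word_cells], 0)).sum(0, keepdim=True)
print(c.shape)         # torch.Size([1, 4])
print(weights.sum(0))  # each element is 1.0

The same exp / sum normalization appears in the forward pass below; note that it is a softmax over gate values per hidden dimension, not over a probability vector.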
2. MultiInputLSTMCell Code
import torch
import torch.nn as nn
from torch.nn import init
class MultiInputLSTMCell(nn.Module):
    def __init__(self, input_size, hidden_size, use_bias=True):
        super(MultiInputLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.use_bias = use_bias
        # Weight matrices W for the linear transforms of the character input
        # x_j and the previous hidden state h_{j-1}; each maps to the three
        # stacked gates [i, o, g].
        self.weight_ih = nn.Parameter(torch.FloatTensor(input_size, 3 * hidden_size))
        self.weight_hh = nn.Parameter(torch.FloatTensor(hidden_size, 3 * hidden_size))
        # Weight matrices W for the word gate alpha, applied to the character
        # input x_j and to each word cell state c^w_{b,j}.
        self.alpha_weight_ih = nn.Parameter(torch.FloatTensor(input_size, hidden_size))
        self.alpha_weight_hh = nn.Parameter(torch.FloatTensor(hidden_size, hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(3 * hidden_size))
            self.alpha_bias = nn.Parameter(torch.FloatTensor(hidden_size))
        else:
            self.register_parameter('bias', None)
            self.register_parameter('alpha_bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        # Orthogonal initialization for the input-to-hidden weights.
        init.orthogonal_(self.weight_ih.data)
        init.orthogonal_(self.alpha_weight_ih.data)
        # Identity initialization for the hidden-to-hidden weights:
        # a [hidden_size, hidden_size] identity tiled to [hidden_size, 3 * hidden_size].
        weight_hh_data = torch.eye(self.hidden_size).repeat(1, 3)
        with torch.no_grad():
            self.weight_hh.set_(weight_hh_data)
        # [hidden_size, hidden_size] identity for the alpha gate.
        alpha_weight_hh_data = torch.eye(self.hidden_size)
        with torch.no_grad():
            self.alpha_weight_hh.set_(alpha_weight_hh_data)
        if self.use_bias:
            init.constant_(self.bias.data, val=0)
            init.constant_(self.alpha_bias.data, val=0)
    def forward(self, input_, c_input, hx):
        # input_:  [1, emb_dim], the embedding of the current character
        # c_input: list of num_words tensors, each [1, hidden_size] -- the cell
        #          states of the lexicon words ending at this character
        # hx:      (h_0, c_0), each [1, hidden_size]
        # Note: forward uses self.bias unconditionally, so it assumes use_bias=True.
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        assert batch_size == 1
        # bias_batch: [1, 3 * hidden_size]
        bias_batch = (self.bias.unsqueeze(0)).expand(batch_size, *self.bias.size())
        # wh_b = h_0 @ weight_hh + bias_batch -> [1, 3 * hidden_size]
        wh_b = torch.addmm(bias_batch, h_0, self.weight_hh)
        # wi: [1, 3 * hidden_size]
        wi = torch.mm(input_, self.weight_ih)
        # i, o, g: each [1, hidden_size]
        i, o, g = torch.split(wh_b + wi, split_size_or_sections=self.hidden_size, dim=1)
        i = torch.sigmoid(i)
        o = torch.sigmoid(o)
        g = torch.tanh(g)
        c_num = len(c_input)
        if c_num == 0:
            # Case (1): no word ends at this character -- coupled-gate LSTM update.
            f = 1 - i
            c_1 = f * c_0 + i * g
            h_1 = o * torch.tanh(c_1)
        else:
            # Case (2): fuse the word cells with the candidate cell state.
            # c_input_var: [num_words, hidden_size]
            c_input_var = torch.cat(c_input, 0)
            # alpha_wi: [1, hidden_size]; broadcasts against alpha_wh below
            alpha_wi = torch.addmm(self.alpha_bias, input_, self.alpha_weight_ih)
            # alpha_wh: [num_words, hidden_size]
            alpha_wh = torch.mm(c_input_var, self.alpha_weight_hh)
            # alpha: [num_words, hidden_size]
            alpha = torch.sigmoid(alpha_wi + alpha_wh)
            # Element-wise softmax over the num_words + 1 gates {i, alpha}:
            # alpha: [num_words + 1, hidden_size]
            alpha = torch.exp(torch.cat([i, alpha], 0))
            # alpha_sum: [hidden_size]
            alpha_sum = alpha.sum(0)
            # alpha: [num_words + 1, hidden_size]; each column now sums to 1
            alpha = torch.div(alpha, alpha_sum)
            # merge_i_c: [num_words + 1, hidden_size] -- the candidate g stacked
            # on top of the word cells, to be weighted by the normalized gates
            merge_i_c = torch.cat([g, c_input_var], 0)
            # c_1: [1, hidden_size]
            c_1 = (merge_i_c * alpha).sum(0).unsqueeze(0)
            # h_1: [1, hidden_size]
            h_1 = o * torch.tanh(c_1)
        return h_1, c_1
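A quick smoke test with random tensors: a single character embedding of dimension 56, two injected word cells, and a random previous state.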
emb_dim = 56
hidden_size = 100
model = MultiInputLSTMCell(input_size=emb_dim, hidden_size=hidden_size)
input_ = torch.randn([1, emb_dim])
c_input = [torch.randn([1, hidden_size]) for _ in range(2)]
hx = (torch.randn([1, hidden_size]), torch.randn([1, hidden_size]))
h_1, c_1 = model(input_, c_input, hx)
print('h_1.shape:', h_1.shape)
print('c_1.shape:', c_1.shape)
Output:
h_1.shape: torch.Size([1, 100])
c_1.shape: torch.Size([1, 100])
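Finally, a minimal sketch of driving the cell over a character sequence. The word cells here are random placeholders; in a real Lattice-LSTM-style model they would be produced by a separate word-level cell whenever a lexicon word ends at the current position (the dictionary word_cells_at below is purely illustrative):

# Step the cell over a 4-character sequence; a placeholder word cell is
# injected only at position 2, where a lexicon word is assumed to end.
seq_len = 4
chars = [torch.randn(1, emb_dim) for _ in range(seq_len)]
word_cells_at = {2: [torch.randn(1, hidden_size)]}  # hypothetical lattice
h, c = torch.zeros(1, hidden_size), torch.zeros(1, hidden_size)
for j, x in enumerate(chars):
    h, c = model(x, word_cells_at.get(j, []), (h, c))
print('final h.shape:', h.shape)  # torch.Size([1, 100])

Positions without an entry in word_cells_at pass an empty list, so they take the coupled-gate path of case (1), while position 2 exercises the softmax fusion of case (2).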