1. Code Principles
(1). Flow chart for a single time step t
(2). A worked example
MultiInputLSTMCell and WordLSTMCell are explained in detail in the earlier post! (A minimal stand-in sketch of the two cells follows, so the code in section 2 can be run on its own.)
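If the earlier post (or the LatticeLSTM repo) is not at hand, here is a minimal stand-in sketch of the two cells. This is my own simplified re-implementation, an assumption rather than the original code: it uses nn.Linear layers instead of the repo's explicit weight matrices and skips their initialisation details, but the call signatures and tensor shapes match what the forward pass in section 2 expects, so it can substitute for the import at the top of that code.
import torch
import torch.nn as nn

class WordLSTMCell(nn.Module):
    # LSTM cell without an output gate: given the embeddings of the words matched at the
    # current character and that character's state (h, c), it returns one cell state per word.
    def __init__(self, input_size, hidden_size):
        super(WordLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)

    def forward(self, input_, hx):
        # input_: [num_words, input_size]; h_0, c_0: [1, hidden_size]
        h_0, c_0 = hx
        num_words = input_.size(0)
        gates = self.linear_ih(input_) + self.linear_hh(h_0).expand(num_words, 3 * self.hidden_size)
        i, f, g = torch.split(gates, self.hidden_size, dim=1)
        # one word cell state per matched word: [num_words, hidden_size]
        return torch.sigmoid(f) * c_0.expand(num_words, self.hidden_size) + torch.sigmoid(i) * torch.tanh(g)

class MultiInputLSTMCell(nn.Module):
    # Character-level cell: merges the character's own candidate cell state with the word
    # cell states collected at this character, using softmax-normalised gates.
    def __init__(self, input_size, hidden_size):
        super(MultiInputLSTMCell, self).__init__()
        self.hidden_size = hidden_size
        self.linear_ih = nn.Linear(input_size, 3 * hidden_size)
        self.linear_hh = nn.Linear(hidden_size, 3 * hidden_size)
        self.alpha_linear_ih = nn.Linear(input_size, hidden_size)
        self.alpha_linear_hh = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_, c_input, hx):
        # input_: [1, input_size]; c_input: list of [1, hidden_size] word cell states; h_0, c_0: [1, hidden_size]
        h_0, c_0 = hx
        gates = self.linear_ih(input_) + self.linear_hh(h_0)
        i, o, g = torch.split(gates, self.hidden_size, dim=1)
        i, o, g = torch.sigmoid(i), torch.sigmoid(o), torch.tanh(g)
        if len(c_input) == 0:
            # no word cell state arrives here: coupled input/forget-gate LSTM update
            c_1 = (1 - i) * c_0 + i * g
        else:
            c_input_var = torch.cat(c_input, 0)  # [num_words, hidden_size]
            # one extra gate per word cell state, normalised together with the input gate
            alpha = torch.sigmoid(self.alpha_linear_ih(input_) + self.alpha_linear_hh(c_input_var))
            alpha = torch.exp(torch.cat([i, alpha], 0))
            alpha = alpha / alpha.sum(0)
            c_1 = (torch.cat([g, c_input_var], 0) * alpha).sum(0).unsqueeze(0)
        h_1 = o * torch.tanh(c_1)
        return h_1, c_1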
2. Code Practice
import torch
import torch.nn as nn
import numpy as np
import torch.autograd as autograd
from LatticeLSTM.model.latticelstm import MultiInputLSTMCell,WordLSTMCell
def init_list_of_objects(size):
    list_of_objects = list()
    for i in range(size):
        list_of_objects.append(list())
    return list_of_objects
# Convert the lexicon (gaz) info so that each word is indexed by its last character instead of its first
def convert_forward_gaz_to_backward(forward_gaz):
    length = len(forward_gaz)
    backward_gaz = init_list_of_objects(length)
    for idx in range(length):
        if forward_gaz[idx]:
            assert len(forward_gaz[idx]) == 2
            num = len(forward_gaz[idx][0])
            for idy in range(num):
                the_id = forward_gaz[idx][0][idy]
                the_length = forward_gaz[idx][1][idy]
                # a word of this length starting at idx ends at position the_pos
                the_pos = idx + the_length - 1
                if backward_gaz[the_pos]:
                    backward_gaz[the_pos][0].append(the_id)
                    backward_gaz[the_pos][1].append(the_length)
                else:
                    backward_gaz[the_pos] = [[the_id], [the_length]]
    return backward_gaz
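# Quick sanity check for the conversion (assumed example, not from the original post):
# word 25 (length 2) starts at character 0 and therefore ends at character 1, while
# word 13 (length 3) ends at character 2, so
#   convert_forward_gaz_to_backward([[[25, 13], [2, 3]], [], []])
# returns [[], [[25], [2]], [[13], [3]]].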
class LatticeLSTM(nn.Module):
    def __init__(self, input_dim, hidden_dim, word_drop, word_alphabet_size, word_emb_dim, pretrain_word_emb=None, left2right=True, fix_word_emb=True, gpu=True, use_bias=True):
        super(LatticeLSTM, self).__init__()
        skip_direction = 'forward' if left2right else 'backward'
        print('Build LatticeLSTM...', skip_direction, ',Fix emb:', fix_word_emb, 'gaz drop:', word_drop)
        self.gpu = gpu
        self.hidden_dim = hidden_dim
        self.word_emb = nn.Embedding(word_alphabet_size, word_emb_dim)
        if pretrain_word_emb is not None:
            print('load pretrain word emb...', pretrain_word_emb.shape)
            self.word_emb.weight.data.copy_(torch.from_numpy(pretrain_word_emb))
        else:
            self.word_emb.weight.data.copy_(torch.from_numpy(self.random_embedding(word_alphabet_size, word_emb_dim)))
        if fix_word_emb:
            self.word_emb.weight.requires_grad = False
        self.word_dropout = nn.Dropout(word_drop)
        self.rnn = MultiInputLSTMCell(input_dim, hidden_dim)
        self.word_rnn = WordLSTMCell(word_emb_dim, hidden_dim)
        self.left2right = left2right
        if self.gpu:
            self.rnn = self.rnn.cuda()
            self.word_emb = self.word_emb.cuda()
            self.word_dropout = self.word_dropout.cuda()
            self.word_rnn = self.word_rnn.cuda()

    def random_embedding(self, vocab_size, embedding_dim):
        pretrain_emb = np.empty([vocab_size, embedding_dim])
        scale = np.sqrt(3.0 / embedding_dim)
        for index in range(vocab_size):
            pretrain_emb[index, :] = np.random.uniform(-scale, scale, [1, embedding_dim])
        return pretrain_emb
    def forward(self, input_, skip_input_list, hidden=None):
        # input_: [1, seq_len, embed_dim]
        # skip_input_list: ([[], [[25,13],[2,3]], ..., one entry per character], flag); 25 and 13 are word ids,
        # 2 and 3 the word lengths; the flag is not used in this implementation
        # skip_input: [[], [[25,13],[2,3]], ..., one entry per character]
        skip_input = skip_input_list[0]
        if not self.left2right:
            skip_input = convert_forward_gaz_to_backward(skip_input)
        # input_: [seq_len, 1, embed_dim]
        input_ = input_.transpose(1, 0)
        seq_len = input_.size(0)
        batch_size = input_.size(1)
        assert batch_size == 1
        hidden_out = []
        memory_out = []
        if hidden:
            (hx, cx) = hidden
        else:
            # hx, cx: [1, hidden_dim]
            hx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
            cx = autograd.Variable(torch.zeros(batch_size, self.hidden_dim))
            if self.gpu:
                hx = hx.cuda()
                cx = cx.cuda()
        id_list = list(range(seq_len))
        if not self.left2right:
            id_list = list(reversed(id_list))
        # [[], ..., one empty list per character]
        # stores the word cell states produced by WordLSTMCell, indexed by the character where each word ends
        # (or starts, when running right-to-left)
        input_c_list = init_list_of_objects(seq_len)
        for t in id_list:
            # hx, cx: [1, hidden_size]
            (hx, cx) = self.rnn(input_[t], input_c_list[t], (hx, cx))
            hidden_out.append(hx)
            memory_out.append(cx)
            if skip_input[t]:
                # number of lexicon words starting at character t (ending at t when running right-to-left)
                matched_num = len(skip_input[t][0])
                # look up the embeddings of all words matched at this position
                with torch.no_grad():
                    word_var = autograd.Variable(torch.LongTensor(skip_input[t][0]))
                if self.gpu:
                    word_var = word_var.cuda()
                word_emb = self.word_emb(word_var)
                word_emb = self.word_dropout(word_emb)
                # hx, cx: [1, hidden_size]
                # ct: [num_words, hidden_size]
                ct = self.word_rnn(word_emb, (hx, cx))
                assert ct.size(0) == len(skip_input[t][1])
                # record each word's end (or start) position, so its cell state is injected when that character is processed
                for idx in range(matched_num):
                    length = skip_input[t][1][idx]
                    if self.left2right:
                        input_c_list[t + length - 1].append(ct[idx, :].unsqueeze(0))
                    else:
                        input_c_list[t - length + 1].append(ct[idx, :].unsqueeze(0))
        if not self.left2right:
            hidden_out = list(reversed(hidden_out))
            memory_out = list(reversed(memory_out))
        # output_hidden, output_memory: [seq_len, hidden_size]
        output_hidden, output_memory = torch.cat(hidden_out, 0), torch.cat(memory_out, 0)
        return output_hidden.unsqueeze(0), output_memory.unsqueeze(0)
input_dim,hidden_dim,word_drop,word_alphabet_size,word_emb_dim = 50,100,0.2,1000,30
model = LatticeLSTM(input_dim,hidden_dim,word_drop,word_alphabet_size,word_emb_dim)
max_len = 10
input_ = torch.randn([1,max_len,input_dim]).cuda()
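# Lexicon matches for the 10-character toy input: characters 1-4 each start three words
# (e.g. at character 1, the words with ids 10, 100, 99 of lengths 2, 3, 4); the other characters match no word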
skip_input_list = ([[],[[10,100,99],[2,3,4]],[[10,120,100],[2,3,4]],[[10,120,100],[2,3,4]],[[10,120,100],[2,3,4]],[],[],[],[],[]],True)
h,m = model(input_,skip_input_list)
print('h.shape:',h.shape)
print('m.shape:',m.shape)
Output:
Build LatticeLSTM... forward ,Fix emb: True gaz drop: 0.2
h.shape: torch.Size([1, 10, 100])
m.shape: torch.Size([1, 10, 100])
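For completeness, the backward direction can be exercised with the same driver by flipping left2right (an assumed usage example, not shown in the original post); the gaz list is then converted internally by convert_forward_gaz_to_backward, so each word's cell state is injected at its first character instead of its last:
model_bw = LatticeLSTM(input_dim, hidden_dim, word_drop, word_alphabet_size, word_emb_dim, left2right=False)
h_bw, m_bw = model_bw(input_, skip_input_list)
print('h_bw.shape:', h_bw.shape)  # torch.Size([1, 10, 100])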