Sentence Vectors
To obtain a complete understanding of a sentence, we convert its word vectors into a single sentence vector; an encoding method implements this operation.
Sequence-to-Sequence (seq2seq)
1. What is seq2seq?
Seq2seq is a sequence-to-sequence model: it receives one sequence and outputs another. It generally consists of an encoding stage (the encoder processes the input item by item) and a decoding stage (the decoder emits the output sequence item by item).
Applications: reading comprehension, text summarization, chatbots, image captioning, language processing, and machine translation (its best-known use case).
2. Pipeline
input → encoder → context → decoder → output
1. input → encoder (compression), implemented with an RNN
Each step's output depends on the current input and the previous hidden vector, so encoding necessarily proceeds item by item.
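This step-by-step behavior can be sketched with a bare `nn.RNNCell`; the sizes below are illustrative assumptions, not tied to the model later in these notes:

```python
import torch
from torch import nn

# Hypothetical sizes, for illustration only.
emb_dim, units, seq_len, batch = 16, 32, 8, 1

cell = nn.RNNCell(emb_dim, units)
x = torch.randn(batch, seq_len, emb_dim)   # an already-embedded input sequence
h = torch.zeros(batch, units)              # initial hidden vector

# Each step consumes the current token AND the previous hidden vector,
# so the whole encoding is inherently sequential.
for t in range(seq_len):
    h = cell(x[:, t], h)

print(h.shape)  # torch.Size([1, 32]) -- the final hidden state
```

The final `h` is exactly the kind of fixed-size summary that the context vector in the next step is built from.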
2. encoder → context
The context acts as a bridge: it is a vector representation the machine can work with, taken as the encoder's output at the last token.
3. context → decoder (decompression), implemented with an RNN
Special <start> and <end> tokens mark the sequence boundaries.
Each step feeds the previous step's output back in as the current input, combines it with the previous hidden vector, and produces the current output.
Each step's output is a distribution over the whole vocabulary, and we usually pick the single best token. This is greedy decoding: it cannot guarantee a globally optimal sequence, so we optimize it.
Optimizations:
(1) Exhaustive search: consider every possible output sequence. It finds the optimum but its time complexity is prohibitively high.
(2) Beam search: keep only the top 3 to 5 candidates at each step. It is still a greedy algorithm, but a good compromise between the two extremes above; its complexity is roughly O(T) for a fixed beam width, where T is the sequence length.
Note: we take log P to prevent numerical underflow; we negate it because we conventionally minimize; and we average the summed score over the sequence length so the search does not systematically favor shorter sentences.
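To make these tricks concrete, here is a minimal, self-contained beam search over a toy next-token distribution (the `step_probs` function is a made-up stand-in for a real decoder, not part of the model in these notes):

```python
import math

def step_probs(prefix):
    # Toy distribution: favor "a" early, then "<end>"; purely illustrative.
    if len(prefix) < 2:
        return {"a": 0.6, "b": 0.3, "<end>": 0.1}
    return {"a": 0.2, "b": 0.1, "<end>": 0.7}

def beam_search(beam_width=2, max_len=5):
    # Each beam entry: (sum of -log P, token list). Lower score is better.
    beams = [(0.0, [])]
    finished = []
    for _ in range(max_len):
        candidates = []
        for score, prefix in beams:
            for tok, p in step_probs(prefix).items():
                entry = (score - math.log(p), prefix + [tok])  # accumulate -log P
                (finished if tok == "<end>" else candidates).append(entry)
        # Keep only the top-k partial hypotheses: still greedy, but wider.
        beams = sorted(candidates)[:beam_width]
        if not beams:
            break
    # Length-normalize so short sequences are not automatically preferred.
    return min(finished, key=lambda e: e[0] / len(e[1]))

score, seq = beam_search()
print(seq)  # → ['a', 'a', '<end>']
```

Without the division by `len(e[1])`, the single-token hypothesis `['<end>']` would win whenever probabilities are small, which is exactly the short-sentence bias mentioned above.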
3. Summary
We used an RNN model to produce the sentence embedding, but this approach has drawbacks:
(1) it is sequential, and therefore slow;
(2) it struggles with long-term dependencies;
(3) it is a shallow model (deep horizontally across time steps, but shallow vertically).
To address these problems, attention was proposed.
Seq2Seq with Attention
Advantages:
(1) it alleviates the vanishing-gradient problem;
(2) it is highly interpretable (when the model errs, the attention weights point to where the error likely came from).
The encoder part is unchanged.
In the decoder, each hidden state h attends to the encoder outputs when generating its output: dotting h with each encoder output g gives a score, the scores are normalized (e.g. with softmax) so they sum to 1, and the resulting weighted sum C' serves as the context vector used to generate y', the current output.
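This weighting can be sketched with bare tensors before reading the model code; all sizes here are illustrative assumptions:

```python
import torch
from torch.nn.functional import softmax

n, step, units = 1, 8, 32                  # hypothetical batch/length/size
enc_out = torch.randn(n, step, units)      # encoder output at every step
h = torch.randn(n, units)                  # current decoder hidden state

# Score each encoder step against h (a plain dot-product score here).
scores = torch.matmul(h.unsqueeze(1), enc_out.permute(0, 2, 1))  # [n, 1, step]
weights = softmax(scores, dim=2)           # normalize so the weights sum to 1
context = torch.matmul(weights, enc_out)   # [n, 1, units] weighted sum

print(weights.sum().item())  # ≈ 1.0
```

The model code below follows the same pattern, except that h is first passed through a learned linear layer (`self.attn`) before scoring.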
Encoder:
# encoder
self.enc_embeddings = nn.Embedding(enc_v_dim, emb_dim)
self.enc_embeddings.weight.data.normal_(0, 0.1)
self.encoder = nn.LSTM(emb_dim, units, 1, batch_first=True)
def encode(self, x):  # x: [n, step] = [32, 8]
    embedded = self.enc_embeddings(x)  # [n, step, emb] = [32, 8, 16]
    hidden = (torch.zeros(1, x.shape[0], self.units),
              torch.zeros(1, x.shape[0], self.units))
    o, (h, c) = self.encoder(embedded, hidden)
    # o: [n, step, units] = [32, 8, 32]; h, c: [num_layers * num_directions = 1, 32, 32]
    return o, h, c
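As a standalone sanity check of the shapes in the comments above (layer sizes copied from those comments, reproduced outside the class):

```python
import torch
from torch import nn

# Sizes taken from the shape comments in these notes; 27 is the vocab size.
enc_v_dim, emb_dim, units = 27, 16, 32
emb = nn.Embedding(enc_v_dim, emb_dim)
lstm = nn.LSTM(emb_dim, units, 1, batch_first=True)

x = torch.randint(0, enc_v_dim, (32, 8))       # [n, step]
h0 = torch.zeros(1, 32, units)                 # [num_layers, n, units]
o, (h, c) = lstm(emb(x), (h0, h0.clone()))

print(o.shape, h.shape)  # torch.Size([32, 8, 32]) torch.Size([1, 32, 32])
```

With `batch_first=True`, `o` holds the output for every time step, while `h` and `c` hold only the final step; the attention decoder needs all of `o`, not just `h`.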
Decoder:
# decoder
self.dec_embeddings = nn.Embedding(dec_v_dim, emb_dim)
self.attn = nn.Linear(units, units)
self.decoder_cell = nn.LSTMCell(emb_dim, units)
self.decoder_dense = nn.Linear(units * 2, dec_v_dim)
def inference(self, x, return_align=False):
    self.eval()
    o, hx, cx = self.encode(x)  # o: [n, step, units] = [1, 8, 32]; hx, cx: [1, 1, 32]
    hx, cx = hx[0], cx[0]  # [n, units] = [1, 32]
    # initialize with the start token
    start = torch.ones(x.shape[0], 1)  # [n, 1] = [1, 1]
    start[:, 0] = torch.tensor(self.start_token)
    start = start.type(torch.LongTensor)
    dec_emb_in = self.dec_embeddings(start)  # [n, 1, emb_dim] = [1, 1, 16]
    dec_emb_in = dec_emb_in.permute(1, 0, 2)  # [1, n, emb_dim] = [1, 1, 16]
    dec_in = dec_emb_in[0]  # [n, emb_dim] = [1, 16]
    output = []
    for i in range(self.max_pred_len):  # 11 steps
        # dot the transformed hidden state with every encoder output:
        # [1, 1, 32] x [1, 32, 8] -> [n, 1, step] = [1, 1, 8]
        attn_prod = torch.matmul(self.attn(hx.unsqueeze(1)), o.permute(0, 2, 1))
        att_weight = softmax(attn_prod, dim=2)  # attention weights [n, 1, step] = [1, 1, 8]
        context = torch.matmul(att_weight, o)  # weighted sum of o: [n, 1, units] = [1, 1, 32]
        hx, cx = self.decoder_cell(dec_in, (hx, cx))  # updated decoder hidden state [1, 32]
        hc = torch.cat([context.squeeze(1), hx], dim=1)  # concatenate: [n, units * 2] = [1, 64]
        result = self.decoder_dense(hc)  # [n, dec_v_dim] = [1, 27]
        # feed the previous output back in as the next input
        result = result.argmax(dim=1).view(-1, 1)  # greedy pick of the output token [1, 1]
        dec_in = self.dec_embeddings(result).permute(1, 0, 2)[0]  # [1,1] -> [1,1,16] -> [1,16]
        output.append(result)
    output = torch.stack(output, dim=0)  # [max_pred_len, n, 1] = [11, 1, 1]
    self.train()
    return output.permute(1, 0, 2).view(-1, self.max_pred_len)  # [1, 11, 1] -> [1, 11]
Computing the loss:
def train_logit(self, x, y):
    o, hx, cx = self.encode(x)  # o: [32, 8, 32]; hx, cx: [1, 32, 32]
    hx, cx = hx[0], cx[0]  # [n, units] = [32, 32]
    dec_in = y[:, :-1]  # [n, step] = [32, 10]
    dec_emb_in = self.dec_embeddings(dec_in)  # [n, step, emb_dim] = [32, 10, 16]
    dec_emb_in = dec_emb_in.permute(1, 0, 2)  # [step, n, emb_dim] = [10, 32, 16]
    output = []
    for i in range(dec_emb_in.shape[0]):  # 10 steps
        # dot the transformed hidden state with every encoder output:
        # [32, 1, 32] x [32, 32, 8] -> [n, 1, step] = [32, 1, 8]
        attn_prod = torch.matmul(self.attn(hx.unsqueeze(1)), o.permute(0, 2, 1))
        att_weight = softmax(attn_prod, dim=2)  # attention weights [n, 1, step] = [32, 1, 8]
        context = torch.matmul(att_weight, o)  # weighted sum of o: [n, 1, units] = [32, 1, 32]
        hx, cx = self.decoder_cell(dec_emb_in[i], (hx, cx))  # updated decoder hidden state [32, 32]
        hc = torch.cat([context.squeeze(1), hx], dim=1)  # concatenate: [n, units * 2] = [32, 64]
        result = self.decoder_dense(hc)  # [n, dec_v_dim] = [32, 27]
        output.append(result)
    output = torch.stack(output, dim=0)  # [step, n, dec_v_dim] = [10, 32, 27]
    return output.permute(1, 0, 2)  # [n, step, dec_v_dim] = [32, 10, 27]
def step(self, x, y):
    self.opt.zero_grad()
    batch_size = x.shape[0]
    logit = self.train_logit(x, y)  # predicted logits
    dec_out = y[:, 1:]  # targets: the sequence shifted past <GO>
    loss = cross_entropy(logit.reshape(-1, self.dec_v_dim), dec_out.reshape(-1))
    loss.backward()
    self.opt.step()
    return loss.detach().numpy()
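Note how `train_logit` and `step` split the target sequence: the decoder consumes `y[:, :-1]` (everything before the last token) while the loss compares against `y[:, 1:]` (everything after `<GO>`). A toy example of this teacher-forcing shift, with made-up token ids:

```python
import torch

# Hypothetical target sequence <GO> 2 6 <EOS>, using made-up ids.
GO, EOS = 14, 13
y = torch.tensor([[GO, 2, 6, EOS]])

dec_in = y[:, :-1]   # what the decoder sees:  [<GO>, 2, 6]
dec_out = y[:, 1:]   # what it must predict:   [2, 6, <EOS>]

print(dec_in.tolist(), dec_out.tolist())  # [[14, 2, 6]] [[2, 6, 13]]
```

At every position the model is asked to predict the next token given the true previous one, which is why both slices have the same length.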
Train and inspect the results:
def train():
    dataset = utils.DateData(4000)
    print("Chinese time order: yy/mm/dd ", dataset.date_cn[:3], "\nEnglish time order: dd/M/yyyy", dataset.date_en[:3])
    print("Vocabularies: ", dataset.vocab)
    print(f"x index sample: \n{dataset.idx2str(dataset.x[0])}\n{dataset.x[0]}",
          f"\ny index sample: \n{dataset.idx2str(dataset.y[0])}\n{dataset.y[0]}")
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    """
    Chinese time order: yy/mm/dd ['31-04-26', '04-07-18', '33-06-06']
    English time order: dd/M/yyyy ['26/Apr/2031', '18/Jul/2004', '06/Jun/2033']
    Vocabularies: {'<PAD>', '8', '<GO>', 'Oct', 'Jun', 'Nov', 'Feb', 'Apr', '2', 'Mar', '-', 'Jan', '1', '3', '6', 'May', '<EOS>', '/', 'Sep', '9', '7', '4', 'Jul', '0', '5', 'Dec', 'Aug'}
    x index sample:
    31-04-26
    [6 4 1 3 7 1 5 9]
    y index sample:
    <GO>26/Apr/2031<EOS>
    [14 5 9 2 15 2 5 3 6 4 13]
    """
    model = Seq2Seq(dataset.num_word, dataset.num_word, emb_dim=16, units=32, max_pred_len=11,
                    start_token=dataset.start_token, end_token=dataset.end_token)
    for i in range(100):
        for batch_idx, batch in enumerate(loader):
            bx, by, decoder_len = batch  # x: [32, 8], y: [32, 11]
            loss = model.step(bx, by)
            if batch_idx % 70 == 0:
                target = dataset.idx2str(by[0, 1:-1].data.numpy())
                pred = model.inference(bx[0:1])
                res = dataset.idx2str(pred[0].data.numpy())
                src = dataset.idx2str(bx[0].data.numpy())
                print(
                    "Epoch: ", i,
                    "| t: ", batch_idx,
                    "| loss: %.3f" % loss,
                    "| input: ", src,
                    "| target: ", target,
                    "| inference: ", res,
                )
"""
Epoch: 0 | t: 0 | loss: 3.303 | input: 16-01-18 | target: 18/Jan/2016 | inference: 11111111111
Epoch: 0 | t: 70 | loss: 2.342 | input: 04-01-31 | target: 31/Jan/2004 | inference: /2000<EOS>
Epoch: 1 | t: 0 | loss: 1.721 | input: 13-01-15 | target: 15/Jan/2013 | inference: 1///200<EOS>
Epoch: 1 | t: 70 | loss: 1.310 | input: 97-09-16 | target: 16/Sep/1997 | inference: 1///2000<EOS>
Epoch: 2 | t: 0 | loss: 1.199 | input: 25-05-02 | target: 02/May/2025 | inference: 20///2000<EOS>
Epoch: 2 | t: 70 | loss: 1.154 | input: 76-09-14 | target: 14/Sep/1976 | inference: 20/Mar/2020<EOS>
Epoch: 3 | t: 0 | loss: 1.102 | input: 79-10-16 | target: 16/Oct/1979 | inference: 11/Dec/2019<EOS>
Epoch: 3 | t: 70 | loss: 1.055 | input: 08-06-22 | target: 22/Jun/2008 | inference: 22/Mar/2000<EOS>
Epoch: 4 | t: 0 | loss: 1.010 | input: 01-01-12 | target: 12/Jan/2001 | inference: 11/May/2019<EOS>
Epoch: 4 | t: 70 | loss: 0.960 | input: 20-02-18 | target: 18/Feb/2020 | inference: 20/Mar/2000<EOS>
Epoch: 5 | t: 0 | loss: 0.904 | input: 06-12-20 | target: 20/Dec/2006 | inference: 02/Mar/2018<EOS>
Epoch: 5 | t: 70 | loss: 0.840 | input: 85-02-08 | target: 08/Feb/1985 | inference: 07/Oct/1988<EOS>
"""