ESIM
Overall architecture:
Input (input embedding)
Use pre-trained word vectors, or add a trainable embedding layer
A bidirectional LSTM extracts features from the input; its hidden states are kept for the later attention step
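A minimal sketch of the input-encoding step (sizes here are hypothetical, chosen only for illustration): a bidirectional LSTM doubles the feature dimension, since forward and backward hidden states are concatenated.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: embedding dim 50, hidden size 8 -> BiLSTM output dim 2 * 8 = 16
lstm = nn.LSTM(input_size=50, hidden_size=8, batch_first=True, bidirectional=True)
emb = torch.randn(2, 5, 50)   # (batch, seq_len, embed_dim)
a_bar, _ = lstm(emb)          # hidden states kept for the attention step
print(a_bar.shape)            # torch.Size([2, 5, 16])
```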
Local inference modeling layer
Inter-sentence attention:
First, compute the attention score matrix with a dot product:
Obtain a_tilde (the attention-weighted sum of b_bar) and b_tilde (the attention-weighted sum of a_bar)
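In the paper's notation, the soft-alignment step above can be written as:

```latex
e_{ij} = \bar{a}_i^{\top}\,\bar{b}_j
\qquad
\tilde{a}_i = \sum_{j} \frac{\exp(e_{ij})}{\sum_{k}\exp(e_{ik})}\,\bar{b}_j
\qquad
\tilde{b}_j = \sum_{i} \frac{\exp(e_{ij})}{\sum_{k}\exp(e_{kj})}\,\bar{a}_i
```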
Matrix concatenation
Purpose: to obtain richer features
That is, after obtaining the encodings and their aligned (attention-weighted) counterparts, the next step computes the differences between the two; the authors argue this strengthens each word's representation. Two operations are used: element-wise subtraction and element-wise multiplication.
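The attention and enhancement steps can be traced on toy tensors (all shapes below are assumed for illustration; batch 2, sequence lengths 3 and 4, BiLSTM output dim 6). Concatenating the encoding, its aligned counterpart, their difference, and their product quadruples the feature dimension.

```python
import torch

a_bar = torch.randn(2, 3, 6)   # (batch, len_a, 2h)
b_bar = torch.randn(2, 4, 6)   # (batch, len_b, 2h)

# Dot-product attention scores: (batch, len_a, len_b)
e = torch.matmul(a_bar, b_bar.permute(0, 2, 1))

# a_tilde: each a position attends over b (softmax along dim=2)
a_tilde = torch.matmul(torch.softmax(e, dim=2), b_bar)                   # (2, 3, 6)
# b_tilde: each b position attends over a (softmax along dim=1, transposed)
b_tilde = torch.matmul(torch.softmax(e, dim=1).permute(0, 2, 1), a_bar)  # (2, 4, 6)

# Enhancement: concat encoding, aligned encoding, difference, product -> 4x dims
m_a = torch.cat([a_bar, a_tilde, a_bar - a_tilde, a_bar * a_tilde], dim=2)  # (2, 3, 24)
m_b = torch.cat([b_bar, b_tilde, b_bar - b_tilde, b_bar * b_tilde], dim=2)  # (2, 4, 24)
```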
Inference composition layer
Feed m_a and m_b into a bidirectional LSTM to obtain v_a and v_b; apply average pooling and max pooling to each, yielding v_a,ave, v_a,max, v_b,ave, and v_b,max, then concatenate these four vectors into v.
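The pooling step turns the variable-length composition outputs into a fixed-length vector (toy shapes assumed: batch 2, lengths 3 and 4, composition dim 6):

```python
import torch

v_a = torch.randn(2, 3, 6)   # composition BiLSTM output for sentence a
v_b = torch.randn(2, 4, 6)   # composition BiLSTM output for sentence b

# Average and max pooling over the time dimension
v_a_avg = v_a.mean(dim=1)          # (2, 6)
v_a_max = v_a.max(dim=1).values    # (2, 6)
v_b_avg = v_b.mean(dim=1)
v_b_max = v_b.max(dim=1).values

# Fixed-length representation fed to the classifier
v = torch.cat([v_a_avg, v_a_max, v_b_avg, v_b_max], dim=1)  # (2, 24)
```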
Output prediction
Classify with fully connected layers + tanh activations + a softmax
Code
import torch
import torch.nn as nn

class Esim(nn.Module):
    def __init__(self, config, TEXT):
        super(Esim, self).__init__()
        # Input encoding: pre-trained word vectors + two-layer BiLSTM
        self.embed = nn.Embedding(*TEXT.vocab.vectors.size())
        self.embed.weight.data.copy_(TEXT.vocab.vectors)
        self.a_lstm = nn.LSTM(input_size=TEXT.vocab.vectors.size()[1],
                              hidden_size=config.hidden_size,
                              batch_first=True,
                              bidirectional=True,
                              num_layers=2,
                              dropout=0.2)
        self.b_lstm = nn.LSTM(input_size=TEXT.vocab.vectors.size()[1],
                              hidden_size=config.hidden_size,
                              batch_first=True,
                              bidirectional=True,
                              num_layers=2,
                              dropout=0.2)
        # Inference composition: input is the 8 * hidden_size enhanced features
        self.a_lstm_infer = nn.LSTM(8 * config.hidden_size,
                                    config.hidden_size,
                                    batch_first=True,
                                    bidirectional=True)
        self.b_lstm_infer = nn.LSTM(8 * config.hidden_size,
                                    config.hidden_size,
                                    batch_first=True,
                                    bidirectional=True)
        # Classifier: returns logits; the softmax is usually folded into the
        # loss (e.g. nn.CrossEntropyLoss)
        self.linear = nn.Sequential(nn.Linear(8 * config.hidden_size, 2 * config.hidden_size),
                                    nn.Tanh(),
                                    nn.Linear(2 * config.hidden_size, config.linear_size),
                                    nn.Tanh(),
                                    nn.Linear(config.linear_size, config.classes))

    def forward(self, a, b):
        emb_a = self.embed(a)
        emb_b = self.embed(b)
        a_bar, _ = self.a_lstm(emb_a)   # (batch, len_a, 2 * hidden_size)
        b_bar, _ = self.b_lstm(emb_b)   # (batch, len_b, 2 * hidden_size)
        # Dot-product attention scores: (batch, len_a, len_b)
        e = torch.matmul(a_bar, b_bar.permute(0, 2, 1))
        # Soft alignment: a_tilde attends over b (dim=2), b_tilde over a (dim=1)
        a_tilde = torch.matmul(torch.softmax(e, dim=2), b_bar)
        b_tilde = torch.matmul(torch.softmax(e, dim=1).permute(0, 2, 1), a_bar)
        # Enhancement: [x; x~; x - x~; x * x~] -> 8 * hidden_size features
        ma = torch.cat([a_bar, a_tilde, a_bar - a_tilde, a_bar * a_tilde], dim=2)
        mb = torch.cat([b_bar, b_tilde, b_bar - b_tilde, b_bar * b_tilde], dim=2)
        va, _ = self.a_lstm_infer(ma)
        vb, _ = self.b_lstm_infer(mb)
        # Average and max pooling over time, concatenated into a fixed-length v
        va_avg = torch.mean(va, dim=1)
        va_max = torch.max(va, dim=1)[0]
        vb_avg = torch.mean(vb, dim=1)
        vb_max = torch.max(vb, dim=1)[0]
        v = torch.cat([va_avg, va_max, vb_avg, vb_max], dim=1)
        return self.linear(v)