The previous post introduced the ESIM model (https://blog.csdn.net/zhang2010hao/article/details/87913910). This post introduces another text-matching model, BiMPM, which outperforms ESIM on some tasks.
Paper link: http://tongtianta.site/paper/1759
Deep-learning approaches to judging the similarity of two sentences currently fall into two main families. The first is the Siamese network (ABCNN, SiaGRU, etc.): the two input sentences are passed through the same weight-sharing network to obtain two sentence vectors, and matching is done on those two vectors. Sharing parameters keeps the number of parameters to learn small and makes training easier, but it has a clear drawback: the interaction information between the two sentences is barely used. The second family is matching-aggregation (e.g. ESIM and BiMPM): the smaller units of the two sentences are matched against each other first (for example, the outputs of their LSTMs at different time steps), the matching results are aggregated by a neural network (CNN or LSTM) into a single vector, and the final decision is made on that vector. This approach captures the interaction features between the two sentences.
BiMPM's main improvement is to match the two sentences from multiple perspectives, addressing the insufficient matching of earlier interaction models.
I. Principle
The architecture of BiMPM is shown in Figure 1.

Let the two input sentences be $P$ and $Q$. The BiMPM model can be implemented in five steps:
1) Word representation layer. Pre-trained word embeddings map $P$ and $Q$ into the word-vector sequences $[p_1, p_2, \dots, p_M]$ and $[q_1, q_2, \dots, q_N]$, where $M$ and $N$ are the lengths of $P$ and $Q$ respectively.
2) Context representation layer. A bidirectional LSTM encodes $P$ and $Q$ into the hidden-state sequences $[h_1^p, \dots, h_M^p]$ and $[h_1^q, \dots, h_N^q]$, where $h_i^p$ is obtained by concatenating the forward and backward hidden states of the word at position $i$ of $P$:

$$h_i^p = [\overrightarrow{h}_i^{\,p} \, ; \, \overleftarrow{h}_i^{\,p}] \tag{1}$$
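As a minimal illustration (not part of the original code below; the sizes and the random input are placeholder assumptions), the forward and backward halves in (1) can be read directly off a PyTorch bidirectional LSTM:

import torch
import torch.nn as nn

# context representation: a single bidirectional LSTM over the word embeddings
word_dim, hidden_size = 300, 100
context_lstm = nn.LSTM(word_dim, hidden_size, num_layers=1,
                       bidirectional=True, batch_first=True)

p = torch.randn(2, 7, word_dim)                            # (batch, M, word_dim)
h_p, _ = context_lstm(p)                                   # (batch, M, 2 * hidden_size) = [h_fw ; h_bw], eq. (1)
h_p_fw, h_p_bw = torch.split(h_p, hidden_size, dim=-1)     # forward / backward halves used by the matching layer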
3) Matching (interaction) layer
Based on $\{h_i^p\}$ and $\{h_j^q\}$, the interaction between $P$ and $Q$ is computed, producing the matched vector sequences $\{m_i^p\}$ and $\{m_j^q\}$. The interaction vectors are computed as follows.
The paper uses four strategies to match $P$ against $Q$; only the direction $P \rightarrow Q$ is described here, since $Q \rightarrow P$ works the same way.
All four strategies rely on a multi-perspective cosine matching function that measures the similarity between two vectors $v_1$ and $v_2$:

$$m = f_m(v_1, v_2; W) \tag{2}$$

$$m_k = \mathrm{cosine}(W_k \circ v_1, \; W_k \circ v_2) \tag{3}$$

where $v_1$ and $v_2$ are two $d$-dimensional vectors, $W \in \mathbb{R}^{l \times d}$ is a learnable parameter matrix whose $k$-th row $W_k$ defines one of the $l$ perspectives, and $\circ$ denotes element-wise multiplication. A short sketch of $f_m$ is given next, followed by the four matching strategies.
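A minimal sketch of $f_m$ from equations (2)-(3) (the function name and shapes are illustrative; the full model in section II uses a vectorized variant of the same idea):

import torch
import torch.nn.functional as F

def multi_perspective_match(v1, v2, w):
    # f_m(v1, v2; W): v1, v2 are (batch, d), w is (l, d); returns (batch, l)
    v1 = w.unsqueeze(0) * v1.unsqueeze(1)        # (batch, l, d): each perspective re-weights v1
    v2 = w.unsqueeze(0) * v2.unsqueeze(1)        # (batch, l, d)
    return F.cosine_similarity(v1, v2, dim=2)    # one cosine similarity per perspective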

a. Full-matching. Each forward (backward) hidden state of $P$ is matched against the last forward (first backward) hidden state of $Q$, producing $m_i^{full\text{-}fw}$ and $m_i^{full\text{-}bw}$:

$$m_i^{full\text{-}fw} = f_m(\overrightarrow{h}_i^{\,p}, \overrightarrow{h}_N^{\,q}; W^1) \tag{4}$$

$$m_i^{full\text{-}bw} = f_m(\overleftarrow{h}_i^{\,p}, \overleftarrow{h}_1^{\,q}; W^2) \tag{5}$$

b. Max-pooling-matching. Each forward (backward) hidden state of $P$ is matched against every forward (backward) hidden state of $Q$, and only the element-wise maximum is kept, producing $m_i^{max\text{-}fw}$ and $m_i^{max\text{-}bw}$:

$$m_i^{max\text{-}fw} = \max_{j \in (1, \dots, N)} f_m(\overrightarrow{h}_i^{\,p}, \overrightarrow{h}_j^{\,q}; W^3) \tag{6}$$

$$m_i^{max\text{-}bw} = \max_{j \in (1, \dots, N)} f_m(\overleftarrow{h}_i^{\,p}, \overleftarrow{h}_j^{\,q}; W^4) \tag{7}$$

c. Attentive-matching. Cosine similarities between the hidden states of $P$ and $Q$ are used as attention weights to build a weighted-mean representation of $Q$, which is then matched against:

$$\alpha_{i,j}^{fw} = \mathrm{cosine}(\overrightarrow{h}_i^{\,p}, \overrightarrow{h}_j^{\,q}) \tag{8}$$

$$\overrightarrow{h}_i^{\,mean} = \frac{\sum_{j=1}^{N} \alpha_{i,j}^{fw} \cdot \overrightarrow{h}_j^{\,q}}{\sum_{j=1}^{N} \alpha_{i,j}^{fw}} \tag{9}$$

$$m_i^{att\text{-}fw} = f_m(\overrightarrow{h}_i^{\,p}, \overrightarrow{h}_i^{\,mean}; W^5) \tag{10}$$

$$m_i^{att\text{-}bw} = f_m(\overleftarrow{h}_i^{\,p}, \overleftarrow{h}_i^{\,mean}; W^6) \tag{11}$$

where the backward weights $\alpha_{i,j}^{bw}$ and mean vector $\overleftarrow{h}_i^{\,mean}$ are computed from the backward hidden states in the same way as (8) and (9).

d. Max-attentive-matching. Instead of the weighted mean, the hidden state of $Q$ with the highest cosine similarity is taken as the attentive vector:

$$m_i^{max\text{-}att\text{-}fw} = f_m(\overrightarrow{h}_i^{\,p}, \overrightarrow{h}_i^{\,max}; W^7) \tag{12}$$

$$m_i^{max\text{-}att\text{-}bw} = f_m(\overleftarrow{h}_i^{\,p}, \overleftarrow{h}_i^{\,max}; W^8) \tag{13}$$

In the equations above, $\overrightarrow{h}_i^{\,max}$ denotes the vector in $\{\overrightarrow{h}_j^{\,q}\}$ with the largest cosine similarity to $\overrightarrow{h}_i^{\,p}$; the backward case is analogous.
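Reusing the multi_perspective_match sketch above, full-matching (4) and max-pooling-matching (6) for a single position i of P could look roughly like this (the tensors and sizes are placeholder assumptions, not the author's implementation):

import torch

batch, M, N, hidden_size, l = 2, 7, 9, 100, 20
con_p_fw = torch.randn(batch, M, hidden_size)    # forward states of P
con_h_fw = torch.randn(batch, N, hidden_size)    # forward states of Q
w1 = torch.randn(l, hidden_size)                 # perspectives for full-matching
w3 = torch.randn(l, hidden_size)                 # perspectives for max-pooling-matching

i = 0  # match position i of P against Q

# full-matching, eq. (4): compare with the last forward state of Q
m_full_fw = multi_perspective_match(con_p_fw[:, i, :], con_h_fw[:, -1, :], w1)   # (batch, l)

# max-pooling-matching, eq. (6): compare with every state of Q, keep the max per perspective
scores = torch.stack([multi_perspective_match(con_p_fw[:, i, :], con_h_fw[:, j, :], w3)
                      for j in range(N)], dim=1)                                 # (batch, N, l)
m_max_fw = scores.max(dim=1).values                                              # (batch, l)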
4) Aggregation layer
The aggregation layer condenses the two sequences of matching vectors into a single fixed-length matching vector. A Bi-LSTM is applied to each matching sequence separately, and the last time-step vectors of the two Bi-LSTMs (four vectors in total) are concatenated to form the final matching vector.
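A rough sketch of the aggregation step (the matching-vector tensors and sizes below are assumptions; the full model in section II also applies dropout):

import torch
import torch.nn as nn

l, hidden_size = 20, 100
agg_lstm = nn.LSTM(l * 8, hidden_size, num_layers=1, bidirectional=True, batch_first=True)

mv_p = torch.randn(2, 7, l * 8)    # matching vectors of P: 4 strategies x 2 directions x l perspectives
mv_h = torch.randn(2, 9, l * 8)    # matching vectors of Q

_, (p_last, _) = agg_lstm(mv_p)    # p_last: (2, batch, hidden_size), last forward/backward states
_, (h_last, _) = agg_lstm(mv_h)

# concatenate the four last hidden states -> (batch, hidden_size * 4)
agg = torch.cat([p_last.transpose(0, 1).reshape(-1, hidden_size * 2),
                 h_last.transpose(0, 1).reshape(-1, hidden_size * 2)], dim=1)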
5) Prediction layer
Fully connected layers perform the final classification.
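A minimal sketch of this layer with assumed sizes (two fully connected layers over the aggregated matching vector; a softmax or cross-entropy loss outside the model turns the logits into class probabilities):

import torch.nn as nn

hidden_size, n_labels = 100, 2
classifier = nn.Sequential(
    nn.Linear(hidden_size * 4, hidden_size * 2),   # aggregated matching vector -> hidden
    nn.Tanh(),
    nn.Linear(hidden_size * 2, n_labels),          # hidden -> class logits
)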
6) Summary
BiMPM and ESIM are similar in spirit: both have a matching layer and an aggregation (fusion) layer, and both use Bi-LSTMs. The main difference is how the matching is done; BiMPM exploits information from more perspectives. Because it has more parameters, BiMPM trains more slowly.
II. PyTorch Implementation
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
class BiMPM(nn.Module):
def __init__(self, vocab, word_dim, char_dim, hidden_size, w_size, n_labels, embs=None, char_hidden_size=None,
use_char_emb=True, max_word_len=None, char_vocab_size=None, dropout=0):
super(BiMPM, self).__init__()
self.dropout = dropout
self.use_char_emb = use_char_emb
self.hidden_size = hidden_size
self.char_hidden_size = char_hidden_size
self.max_word_len = max_word_len
self.d = word_dim
self.l = w_size
# ----- Word Representation Layer -----
if use_char_emb:
if not char_hidden_size or not max_word_len or not char_vocab_size:
raise ValueError("if use char embeddings, char_hidden_size, max_word_len, char_vocab_size must not be None")
self.d = self.d + char_hidden_size
self.char_emb = nn.Embedding(char_vocab_size, char_dim, padding_idx=0)
self.char_LSTM = nn.LSTM(
input_size=char_dim,
hidden_size=char_hidden_size,
num_layers=1,
bidirectional=False,
batch_first=True)
self.word_emb = nn.Embedding(len(vocab), word_dim)
# initialize word embeddings with GloVe or other pre-trained word embeddings
if embs:
embvecs, embwords = embs
self.word_emb.weight.data.copy_(torch.from_numpy(np.asarray(embvecs)))
# no fine-tuning for word vectors
self.word_emb.weight.requires_grad = False
# ----- Context Representation Layer -----
self.context_LSTM = nn.LSTM(
input_size=self.d,
hidden_size=hidden_size,
num_layers=1,
bidirectional=True,
batch_first=True
)
# ----- Matching Layer -----
for i in range(1, 9):
setattr(self, f'mp_w{i}',
nn.Parameter(torch.rand(self.l, self.hidden_size)))
# ----- Aggregation Layer -----
self.aggregation_LSTM = nn.LSTM(
input_size=self.l * 8,
hidden_size=hidden_size,
num_layers=1,
bidirectional=True,
batch_first=True
)
# ----- Prediction Layer -----
self.pred_fc1 = nn.Linear(hidden_size * 4, hidden_size * 2)
self.pred_fc2 = nn.Linear(hidden_size * 2, n_labels)
self.reset_parameters()
def reset_parameters(self):
# ----- Word Representation Layer -----
if self.use_char_emb:
nn.init.uniform_(self.char_emb.weight, -0.005, 0.005)
# zero vectors for padding
self.char_emb.weight.data[0].fill_(0)
nn.init.kaiming_normal_(self.char_LSTM.weight_ih_l0)
nn.init.constant_(self.char_LSTM.bias_ih_l0, val=0)
nn.init.orthogonal_(self.char_LSTM.weight_hh_l0)
nn.init.constant_(self.char_LSTM.bias_hh_l0, val=0)
# <unk> vector is randomly initialized
nn.init.uniform_(self.word_emb.weight.data[0], -0.1, 0.1)
# ----- Context Representation Layer -----
nn.init.kaiming_normal_(self.context_LSTM.weight_ih_l0)
nn.init.constant_(self.context_LSTM.bias_ih_l0, val=0)
nn.init.orthogonal_(self.context_LSTM.weight_hh_l0)
nn.init.constant_(self.context_LSTM.bias_hh_l0, val=0)
nn.init.kaiming_normal_(self.context_LSTM.weight_ih_l0_reverse)
nn.init.constant_(self.context_LSTM.bias_ih_l0_reverse, val=0)
nn.init.orthogonal_(self.context_LSTM.weight_hh_l0_reverse)
nn.init.constant_(self.context_LSTM.bias_hh_l0_reverse, val=0)
# ----- Matching Layer -----
for i in range(1, 9):
w = getattr(self, f'mp_w{i}')
nn.init.kaiming_normal_(w)
# ----- Aggregation Layer -----
nn.init.kaiming_normal_(self.aggregation_LSTM.weight_ih_l0)
nn.init.constant_(self.aggregation_LSTM.bias_ih_l0, val=0)
nn.init.orthogonal_(self.aggregation_LSTM.weight_hh_l0)
nn.init.constant_(self.aggregation_LSTM.bias_hh_l0, val=0)
nn.init.kaiming_normal_(self.aggregation_LSTM.weight_ih_l0_reverse)
nn.init.constant_(self.aggregation_LSTM.bias_ih_l0_reverse, val=0)
nn.init.orthogonal_(self.aggregation_LSTM.weight_hh_l0_reverse)
nn.init.constant_(self.aggregation_LSTM.bias_hh_l0_reverse, val=0)
# ----- Prediction Layer ----
nn.init.uniform_(self.pred_fc1.weight, -0.005, 0.005)
nn.init.constant_(self.pred_fc1.bias, val=0)
nn.init.uniform_(self.pred_fc2.weight, -0.005, 0.005)
nn.init.constant_(self.pred_fc2.bias, val=0)
def dropout_fun(self, v):
return F.dropout(v, p=self.dropout, training=self.training)
def forward(self, q1_inputs, q2_inputs, q1_char_inputs=None, q2_char_inputs=None, q1_lens=None, q2_lens=None):
# ----- Matching Layer -----
def mp_matching_func(v1, v2, w):
"""
:param v1: (batch, seq_len, hidden_size)
:param v2: (batch, seq_len, hidden_size) or (batch, hidden_size)
:param w: (l, hidden_size)
:return: (batch, seq_len, l)
"""
seq_len = v1.size(1)
# Trick for large memory requirement
"""
if len(v2.size()) == 2:
v2 = torch.stack([v2] * seq_len, dim=1)
m = []
for i in range(self.l):
# v1: (batch, seq_len, hidden_size)
# v2: (batch, seq_len, hidden_size)
# w: (1, 1, hidden_size)
# -> (batch, seq_len)
m.append(F.cosine_similarity(w[i].view(1, 1, -1) * v1, w[i].view(1, 1, -1) * v2, dim=2))
# list of (batch, seq_len) -> (batch, seq_len, l)
m = torch.stack(m, dim=2)
"""
# (1, 1, hidden_size, l)
w = w.transpose(1, 0).unsqueeze(0).unsqueeze(0)
# (batch, seq_len, hidden_size, l)
v1 = w * torch.stack([v1] * self.l, dim=3)
if len(v2.size()) == 3:
v2 = w * torch.stack([v2] * self.l, dim=3)
else:
v2 = w * torch.stack([torch.stack([v2] * seq_len, dim=1)] * self.l, dim=3)
m = F.cosine_similarity(v1, v2, dim=2)
return m
def mp_matching_func_pairwise(v1, v2, w):
"""
:param v1: (batch, seq_len1, hidden_size)
:param v2: (batch, seq_len2, hidden_size)
:param w: (l, hidden_size)
:return: (batch, l, seq_len1, seq_len2)
"""
# Trick for large memory requirement
"""
m = []
for i in range(self.l):
# (1, 1, hidden_size)
w_i = w[i].view(1, 1, -1)
# (batch, seq_len1, hidden_size), (batch, seq_len2, hidden_size)
v1, v2 = w_i * v1, w_i * v2
# (batch, seq_len, hidden_size->1)
v1_norm = v1.norm(p=2, dim=2, keepdim=True)
v2_norm = v2.norm(p=2, dim=2, keepdim=True)
# (batch, seq_len1, seq_len2)
n = torch.matmul(v1, v2.permute(0, 2, 1))
d = v1_norm * v2_norm.permute(0, 2, 1)
m.append(div_with_small_value(n, d))
# list of (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, l)
m = torch.stack(m, dim=3)
"""
# (1, l, 1, hidden_size)
w = w.unsqueeze(0).unsqueeze(2)
# (batch, l, seq_len, hidden_size)
v1, v2 = w * torch.stack([v1] * self.l, dim=1), w * torch.stack([v2] * self.l, dim=1)
# (batch, l, seq_len, hidden_size->1)
v1_norm = v1.norm(p=2, dim=3, keepdim=True)
v2_norm = v2.norm(p=2, dim=3, keepdim=True)
# (batch, l, seq_len1, seq_len2)
n = torch.matmul(v1, v2.transpose(2, 3))
d = v1_norm * v2_norm.transpose(2, 3)
# (batch, seq_len1, seq_len2, l)
m = div_with_small_value(n, d).permute(0, 2, 3, 1)
return m
def attention(v1, v2):
"""
:param v1: (batch, seq_len1, hidden_size)
:param v2: (batch, seq_len2, hidden_size)
:return: (batch, seq_len1, seq_len2)
"""
# (batch, seq_len1, 1)
v1_norm = v1.norm(p=2, dim=2, keepdim=True)
# (batch, 1, seq_len2)
v2_norm = v2.norm(p=2, dim=2, keepdim=True).permute(0, 2, 1)
# (batch, seq_len1, seq_len2)
a = torch.bmm(v1, v2.permute(0, 2, 1))
d = v1_norm * v2_norm
return div_with_small_value(a, d)
def div_with_small_value(n, d, eps=1e-8):
# too small values are replaced by 1e-8 to prevent it from exploding.
d = d * (d > eps).float() + eps * (d <= eps).float()
return n / d
# ----- Word Representation Layer -----
# (batch, seq_len) -> (batch, seq_len, word_dim)
q1_lens_tmp, q1_indices = torch.sort(q1_lens, descending=True)
q1_input_tmp = q1_inputs[q1_indices]
q2_lens_tmp, q2_indices = torch.sort(q2_lens, descending=True)
q2_input_tmp = q2_inputs[q2_indices]
p = self.word_emb(q1_input_tmp)
h = self.word_emb(q2_input_tmp)
if self.use_char_emb:
# (batch, seq_len, max_word_len) -> (batch * seq_len, max_word_len)
seq_len_p = q1_char_inputs.size(1)
seq_len_h = q2_char_inputs.size(1)
char_p = q1_char_inputs.view(-1, self.max_word_len)
char_h = q2_char_inputs.view(-1, self.max_word_len)
# (batch * seq_len, max_word_len, char_dim)-> (1, batch * seq_len, char_hidden_size)
_, (char_p, _) = self.char_LSTM(self.char_emb(char_p))
_, (char_h, _) = self.char_LSTM(self.char_emb(char_h))
# (batch, seq_len, char_hidden_size)
char_p = char_p.view(-1, seq_len_p, self.char_hidden_size)
char_h = char_h.view(-1, seq_len_h, self.char_hidden_size)
# (batch, seq_len, word_dim + char_hidden_size)
p = torch.cat([p, char_p], dim=-1)
h = torch.cat([h, char_h], dim=-1)
p = self.dropout_fun(p)
h = self.dropout_fun(h)
# ----- Context Representation Layer -----
# (batch, seq_len, hidden_size * 2)
p_packed = pack_padded_sequence(p, q1_lens_tmp, batch_first=True)
p_out, _ = self.context_LSTM(p_packed)
con_p, _ = pad_packed_sequence(p_out, batch_first=True)
h_packed = pack_padded_sequence(h, q2_lens_tmp, batch_first=True)
h_out, _ = self.context_LSTM(h_packed)
con_h, _ = pad_packed_sequence(h_out, batch_first=True)
_, q1_desorte_indices = torch.sort(q1_indices, descending=False)
con_p = con_p[q1_desorte_indices]
_, q2_desorte_indices = torch.sort(q2_indices, descending=False)
con_h = con_h[q2_desorte_indices]
con_p = self.dropout_fun(con_p)
con_h = self.dropout_fun(con_h)
# (batch, seq_len, hidden_size)
con_p_fw, con_p_bw = torch.split(con_p, self.hidden_size, dim=-1)
con_h_fw, con_h_bw = torch.split(con_h, self.hidden_size, dim=-1)
# 1. Full-Matching
# (batch, seq_len, hidden_size), (batch, hidden_size)
# -> (batch, seq_len, l)
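# note: with padded batches, the -1 / 0 indices below may pick a padding time step rather than
# the last / first valid token of the sentence; the paper matches against the last valid time step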
mv_p_full_fw = mp_matching_func(con_p_fw, con_h_fw[:, -1, :], self.mp_w1)
mv_p_full_bw = mp_matching_func(con_p_bw, con_h_bw[:, 0, :], self.mp_w2)
mv_h_full_fw = mp_matching_func(con_h_fw, con_p_fw[:, -1, :], self.mp_w1)
mv_h_full_bw = mp_matching_func(con_h_bw, con_p_bw[:, 0, :], self.mp_w2)
# 2. Maxpooling-Matching
# (batch, seq_len1, seq_len2, l)
mv_max_fw = mp_matching_func_pairwise(con_p_fw, con_h_fw, self.mp_w3)
mv_max_bw = mp_matching_func_pairwise(con_p_bw, con_h_bw, self.mp_w4)
# (batch, seq_len, l)
mv_p_max_fw, _ = mv_max_fw.max(dim=2)
mv_p_max_bw, _ = mv_max_bw.max(dim=2)
mv_h_max_fw, _ = mv_max_fw.max(dim=1)
mv_h_max_bw, _ = mv_max_bw.max(dim=1)
# 3. Attentive-Matching
# (batch, seq_len1, seq_len2)
att_fw = attention(con_p_fw, con_h_fw)
att_bw = attention(con_p_bw, con_h_bw)
# (batch, seq_len2, hidden_size) -> (batch, 1, seq_len2, hidden_size)
# (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, 1)
# -> (batch, seq_len1, seq_len2, hidden_size)
att_h_fw = con_h_fw.unsqueeze(1) * att_fw.unsqueeze(3)
att_h_bw = con_h_bw.unsqueeze(1) * att_bw.unsqueeze(3)
# (batch, seq_len1, hidden_size) -> (batch, seq_len1, 1, hidden_size)
# (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, 1)
# -> (batch, seq_len1, seq_len2, hidden_size)
att_p_fw = con_p_fw.unsqueeze(2) * att_fw.unsqueeze(3)
att_p_bw = con_p_bw.unsqueeze(2) * att_bw.unsqueeze(3)
# (batch, seq_len1, hidden_size) / (batch, seq_len1, 1) -> (batch, seq_len1, hidden_size)
att_mean_h_fw = div_with_small_value(att_h_fw.sum(dim=2), att_fw.sum(dim=2, keepdim=True))
att_mean_h_bw = div_with_small_value(att_h_bw.sum(dim=2), att_bw.sum(dim=2, keepdim=True))
# (batch, seq_len2, hidden_size) / (batch, seq_len2, 1) -> (batch, seq_len2, hidden_size)
att_mean_p_fw = div_with_small_value(att_p_fw.sum(dim=1), att_fw.sum(dim=1, keepdim=True).permute(0, 2, 1))
att_mean_p_bw = div_with_small_value(att_p_bw.sum(dim=1), att_bw.sum(dim=1, keepdim=True).permute(0, 2, 1))
# (batch, seq_len, l)
mv_p_att_mean_fw = mp_matching_func(con_p_fw, att_mean_h_fw, self.mp_w5)
mv_p_att_mean_bw = mp_matching_func(con_p_bw, att_mean_h_bw, self.mp_w6)
mv_h_att_mean_fw = mp_matching_func(con_h_fw, att_mean_p_fw, self.mp_w5)
mv_h_att_mean_bw = mp_matching_func(con_h_bw, att_mean_p_bw, self.mp_w6)
# 4. Max-Attentive-Matching
# (batch, seq_len1, hidden_size)
att_max_h_fw, _ = att_h_fw.max(dim=2)
att_max_h_bw, _ = att_h_bw.max(dim=2)
# (batch, seq_len2, hidden_size)
att_max_p_fw, _ = att_p_fw.max(dim=1)
att_max_p_bw, _ = att_p_bw.max(dim=1)
# (batch, seq_len, l)
mv_p_att_max_fw = mp_matching_func(con_p_fw, att_max_h_fw, self.mp_w7)
mv_p_att_max_bw = mp_matching_func(con_p_bw, att_max_h_bw, self.mp_w8)
mv_h_att_max_fw = mp_matching_func(con_h_fw, att_max_p_fw, self.mp_w7)
mv_h_att_max_bw = mp_matching_func(con_h_bw, att_max_p_bw, self.mp_w8)
# (batch, seq_len, l * 8)
mv_p = torch.cat(
[mv_p_full_fw, mv_p_max_fw, mv_p_att_mean_fw, mv_p_att_max_fw,
mv_p_full_bw, mv_p_max_bw, mv_p_att_mean_bw, mv_p_att_max_bw], dim=2)
mv_h = torch.cat(
[mv_h_full_fw, mv_h_max_fw, mv_h_att_mean_fw, mv_h_att_max_fw,
mv_h_full_bw, mv_h_max_bw, mv_h_att_mean_bw, mv_h_att_max_bw], dim=2)
mv_p = self.dropout_fun(mv_p)
mv_h = self.dropout_fun(mv_h)
# ----- Aggregation Layer -----
# (batch, seq_len, l * 8) -> (2, batch, hidden_size)
_, (agg_p_last, _) = self.aggregation_LSTM(mv_p)
_, (agg_h_last, _) = self.aggregation_LSTM(mv_h)
# 2 * (2, batch, hidden_size) -> 2 * (batch, hidden_size * 2) -> (batch, hidden_size * 4)
x = torch.cat(
[agg_p_last.permute(1, 0, 2).contiguous().view(-1, self.hidden_size * 2),
agg_h_last.permute(1, 0, 2).contiguous().view(-1, self.hidden_size * 2)], dim=1)
x = self.dropout_fun(x)
# ----- Prediction Layer -----
x = torch.tanh(self.pred_fc1(x))
x = self.dropout_fun(x)
x = self.pred_fc2(x)
return x