Text Matching Model: BiMPM


In a previous post I introduced the ESIM model (https://blog.csdn.net/zhang2010hao/article/details/87913910). This post introduces another text matching model, BiMPM, which outperforms ESIM on some tasks.

Paper link: http://tongtianta.site/paper/1759

Current deep-learning approaches to judging the similarity of two sentences fall into two broad families. The first is the Siamese network (ABCNN, SiaGRU, etc.): the two input sentences are encoded by the same weight-sharing network into two sentence vectors, and those two vectors are then matched. Sharing parameters keeps the number of learned weights small and makes training easier, but this family has a drawback: it makes very little use of the interaction between the two sentences. The second family is matching-aggregation (e.g. ESIM and BiMPM): smaller units of the two sentences are matched first (for example, the outputs of each sentence's LSTM at different time steps), the matching results are aggregated by a neural network (a CNN or an LSTM) into a vector, and the final decision is made on that vector. This family can capture the interaction features between the two sentences.

BiMPM's main contribution is matching the two sentences from multiple perspectives, addressing the insufficient matching of earlier interaction models.

I. Principle

The architecture of BiMPM is shown in Figure 1.

Figure 1: BiMPM architecture

Let the two input sentences be Q_1 and Q_2. BiMPM can be implemented in five steps:

1) Word representation layer. Pre-trained word vectors map Q_1 and Q_2 to the word-vector sequences [p_1, p_2, ..., p_M] and [q_1, q_2, ..., q_N], where M and N are the lengths of Q_1 and Q_2.
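
As a minimal sketch of this step (the vocabulary size, batch size and sentence lengths are made up for illustration), the word representation layer is just an embedding lookup over frozen pre-trained vectors:

import torch
import torch.nn as nn

# toy sizes; in practice the embedding matrix is loaded from GloVe or similar
vocab_size, word_dim = 10000, 300
word_emb = nn.Embedding(vocab_size, word_dim)
word_emb.weight.requires_grad = False          # keep the pre-trained vectors frozen

q1_ids = torch.randint(0, vocab_size, (2, 7))  # (batch, M) token ids of Q_1
q2_ids = torch.randint(0, vocab_size, (2, 9))  # (batch, N) token ids of Q_2
p = word_emb(q1_ids)                           # (batch, M, word_dim)
q = word_emb(q2_ids)                           # (batch, N, word_dim)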

2) Context representation layer. A bidirectional LSTM encodes Q_1 and Q_2 into the hidden-state sequences (h_1^p, ..., h_M^p) and (h_1^q, ..., h_N^q), where h_i^X concatenates the forward and backward hidden states at position i of sentence X:

    h_i^X = \begin{bmatrix} \overrightarrow{h}_i^X \\ \overleftarrow{h}_i^X \end{bmatrix}        (1)
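
A minimal sketch of Eq. (1), ignoring padding and sequence packing (the full implementation later in this post uses pack_padded_sequence); the sizes are made up:

import torch
import torch.nn as nn

word_dim, hidden_size = 300, 100
# one shared BiLSTM encodes both sentences
context_lstm = nn.LSTM(word_dim, hidden_size, bidirectional=True, batch_first=True)

p = torch.randn(2, 7, word_dim)            # (batch, M, word_dim)
con_p, _ = context_lstm(p)                 # (batch, M, 2 * hidden_size)

# each time step is the concatenation [forward state; backward state] of Eq. (1)
h_fw, h_bw = torch.split(con_p, hidden_size, dim=-1)
print(h_fw.shape, h_bw.shape)              # twice torch.Size([2, 7, 100])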

3) Matching layer

Based on (h_1^p, ..., h_M^p) and (h_1^q, ..., h_N^q), this layer computes the interaction between Q_1 and Q_2 and produces the matched vector sequences (m_1^p, ..., m_M^p) and (m_1^q, ..., m_N^q). The computation of these matching vectors is described below.

The paper uses four strategies to match Q_1 and Q_2. The description below uses m_i^p as the example; m_i^q is computed in the same way, with the roles of the two sentences swapped.

    m_i^p=\begin{bmatrix} \overrightarrow{m}_i^{full}\\ \overleftarrow{m}_i^{full}\\ \overrightarrow{m}_i^{max}\\ \overleftarrow{m}_i^{max}\\ \overrightarrow{m}_i^{att}\\ \overleftarrow{m}_i^{att}\\ \overrightarrow{m}_i^{maxatt}\\ \overleftarrow{m}_i^{maxatt} \end{bmatrix}        (2)

Define a multi-perspective cosine matching function f_m(\cdot) that measures the similarity between two vectors v_1 and v_2:

    m=f_m(v_1,v_2;W)=\begin{bmatrix} m_1\\ ...\\ m_k\\ ...\\ m_l \end{bmatrix}=\begin{bmatrix} cosine(W_1*v_1,W_1*v_2)\\ ...\\ cosine(W_k*v_1,W_k*v_2)\\ ...\\ cosine(W_l*v_1,W_l*v_2) \end{bmatrix}\in \mathbb{R}^l        (3)

Here v_1\in\mathbb{R}^d and v_2\in\mathbb{R}^d are two d-dimensional vectors and W\in\mathbb{R}^{l\times d} is a learnable parameter matrix: each row W_k re-weights the dimensions of the two vectors before the cosine similarity is taken, so each row contributes one matching perspective. A minimal sketch of f_m is given below; the four matching strategies built on it are shown in Figure 2 and detailed afterwards.
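
Below is a minimal, unoptimized sketch of f_m (the full implementation later in this post uses a vectorized, lower-memory variant); all sizes are made up:

import torch
import torch.nn.functional as F

def multi_perspective_cosine(v1, v2, W):
    """Eq. (3): m_k = cosine(W_k * v1, W_k * v2) for k = 1..l.

    v1, v2: (batch, d)    W: (l, d)    returns: (batch, l)
    """
    # (batch, 1, d) * (1, l, d) -> (batch, l, d): every perspective k re-weights
    # the d dimensions with its own row W_k before the cosine similarity
    v1_w = v1.unsqueeze(1) * W.unsqueeze(0)
    v2_w = v2.unsqueeze(1) * W.unsqueeze(0)
    return F.cosine_similarity(v1_w, v2_w, dim=2)

# toy check
v1, v2, W = torch.randn(4, 100), torch.randn(4, 100), torch.randn(20, 100)
print(multi_perspective_cosine(v1, v2, W).shape)   # torch.Size([4, 20])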

Figure 2: The four matching strategies

a. Full-matching. Each forward hidden state of Q_1 is matched against the last forward hidden state of Q_2, and each backward hidden state against the first backward hidden state of Q_2, producing \overrightarrow{m}_i^{full} and \overleftarrow{m}_i^{full}:

    \overrightarrow{m}_i^{full}=f_m(\overrightarrow{h}_i^p,\overrightarrow{h}_N^q;W^1)        (4)

    \overleftarrow{m}_i^{full}=f_m(\overleftarrow{h}_i^p,\overleftarrow{h}_1^q;W^2)        (5)
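
A small sketch of full-matching in both directions, reusing the multi_perspective_cosine helper from the sketch above; the tensors and sizes are made up:

import torch

batch, M, N, hidden, l = 2, 7, 9, 100, 20
p_fw, p_bw = torch.randn(batch, M, hidden), torch.randn(batch, M, hidden)
q_fw, q_bw = torch.randn(batch, N, hidden), torch.randn(batch, N, hidden)
W1, W2 = torch.randn(l, hidden), torch.randn(l, hidden)

# every position i of P is matched against a single summary of Q:
# the last forward state q_fw[:, -1] and the first backward state q_bw[:, 0]
m_full_fw = torch.stack(
    [multi_perspective_cosine(p_fw[:, i], q_fw[:, -1], W1) for i in range(M)], dim=1)
m_full_bw = torch.stack(
    [multi_perspective_cosine(p_bw[:, i], q_bw[:, 0], W2) for i in range(M)], dim=1)
print(m_full_fw.shape)    # torch.Size([2, 7, 20])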

b. Max-pooling-matching. Each hidden state of Q_1 is matched against every hidden state of Q_2, and the element-wise maximum over Q_2's positions is kept, producing \overrightarrow{m}_i^{max} and \overleftarrow{m}_i^{max}:

    \overrightarrow{m}_i^{max}=\max_{j\in(1,...,N)}f_m(\overrightarrow{h}_i^p,\overrightarrow{h}_j^q;W^3)        (6)

    \overleftarrow{m}_i^{max}=\max_{j\in(1,...,N)}f_m(\overleftarrow{h}_i^p,\overleftarrow{h}_j^q;W^4)        (7)
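
A sketch of max-pooling-matching (again reusing multi_perspective_cosine from above, with made-up tensors); the max over j is taken element-wise on the l-dimensional matching vectors:

import torch

batch, M, N, hidden, l = 2, 7, 9, 100, 20
p_fw, q_fw = torch.randn(batch, M, hidden), torch.randn(batch, N, hidden)
W3 = torch.randn(l, hidden)

# match position i of P against every position j of Q, keep the element-wise max
m_max_fw = torch.stack(
    [torch.stack([multi_perspective_cosine(p_fw[:, i], q_fw[:, j], W3)
                  for j in range(N)], dim=1).max(dim=1).values
     for i in range(M)], dim=1)
print(m_max_fw.shape)     # torch.Size([2, 7, 20])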

c. Attentive-matching. Cosine similarities between the hidden states of the two sentences act as attention weights, and each hidden state of Q_1 is matched against the attention-weighted mean of Q_2's hidden states, producing \overrightarrow{m}_i^{att} and \overleftarrow{m}_i^{att}:

    \overrightarrow{m}_i^{att}=f_m(\overrightarrow{h}_i^p,\overrightarrow{h}_i^{mean};W^5), \quad \overrightarrow{h}_i^{mean}=\frac{\sum_{j=1}^N \overrightarrow{\alpha}_{i,j}\cdot\overrightarrow{h}_j^q}{\sum_{j=1}^N \overrightarrow{\alpha}_{i,j}}        (8)

    \overleftarrow{m}_i^{att}=f_m(\overleftarrow{h}_i^p,\overleftarrow{h}_i^{mean};W^6), \quad \overleftarrow{h}_i^{mean}=\frac{\sum_{j=1}^N \overleftarrow{\alpha}_{i,j}\cdot\overleftarrow{h}_j^q}{\sum_{j=1}^N \overleftarrow{\alpha}_{i,j}}        (9)

    \overrightarrow{\alpha}_{i,j}=cosine(\overrightarrow{h}_i^p,\overrightarrow{h}_j^q)        (10)

    \overleftarrow{\alpha}_{i,j}=cosine(\overleftarrow{h}_i^p,\overleftarrow{h}_j^q)        (11)
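
A sketch of attentive-matching for the forward direction (the backward case is identical); it reuses multi_perspective_cosine from above, and all tensors are made up:

import torch
import torch.nn.functional as F

batch, M, N, hidden, l = 2, 7, 9, 100, 20
p_fw, q_fw = torch.randn(batch, M, hidden), torch.randn(batch, N, hidden)
W5 = torch.randn(l, hidden)

# Eq. (10): alpha_{i,j} = cosine(h_i^p, h_j^q) -> (batch, M, N)
alpha = F.cosine_similarity(p_fw.unsqueeze(2), q_fw.unsqueeze(1), dim=-1)

# attention-weighted mean of Q's states for each position i (guard the denominator)
h_mean = torch.matmul(alpha, q_fw) / alpha.sum(dim=2, keepdim=True).clamp(min=1e-8)

# Eq. (8): match h_i^p against its weighted mean of Q
m_att_fw = torch.stack(
    [multi_perspective_cosine(p_fw[:, i], h_mean[:, i], W5) for i in range(M)], dim=1)
print(m_att_fw.shape)     # torch.Size([2, 7, 20])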

d. Max-attentive-matching. Instead of the weighted mean, the single hidden state of Q_2 most similar to the current state of Q_1 is used, producing \overrightarrow{m}_i^{maxatt} and \overleftarrow{m}_i^{maxatt}:

    \overrightarrow{m}_i^{maxatt}=f_m(\overrightarrow{h}_i^p,\overrightarrow{h}_{max}^q;W^7)        (12)

    \overleftarrow{m}_i^{maxatt}=f_m(\overleftarrow{h}_i^p,\overleftarrow{h}_{max}^q;W^8)        (13)

Here \overrightarrow{h}_{max}^q denotes the vector in (\overrightarrow{h}_1^q,...,\overrightarrow{h}_N^q) with the highest cosine similarity to \overrightarrow{h}_i^p; the backward direction is analogous.
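
A sketch of one direct reading of Eqs. (12)-(13): for each position of P, pick the single state of Q with the highest cosine similarity and match against it (the full implementation later in this post instead takes an element-wise max of the attention-weighted vectors). It reuses multi_perspective_cosine, with made-up tensors:

import torch
import torch.nn.functional as F

batch, M, N, hidden, l = 2, 7, 9, 100, 20
p_fw, q_fw = torch.randn(batch, M, hidden), torch.randn(batch, N, hidden)
W7 = torch.randn(l, hidden)

# cosine similarities between every position of P and Q: (batch, M, N)
alpha = F.cosine_similarity(p_fw.unsqueeze(2), q_fw.unsqueeze(1), dim=-1)

# for each i, gather the state of Q with the highest cosine similarity
best_j = alpha.argmax(dim=2)                                         # (batch, M)
h_max = torch.gather(q_fw, 1, best_j.unsqueeze(-1).expand(-1, -1, hidden))

m_maxatt_fw = torch.stack(
    [multi_perspective_cosine(p_fw[:, i], h_max[:, i], W7) for i in range(M)], dim=1)
print(m_maxatt_fw.shape)  # torch.Size([2, 7, 20])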

4) Aggregation layer

The aggregation layer condenses the two sequences of matching vectors into one fixed-length matching vector: a Bi-LSTM is applied to each matching sequence, and the last time-step vectors of the two Bi-LSTMs (four vectors in total, one per direction per sentence) are concatenated.
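
A minimal sketch of the aggregation layer with made-up sizes: a shared BiLSTM runs over each sequence of matching vectors and the final hidden states of both directions are concatenated:

import torch
import torch.nn as nn

batch, M, N, l, hidden = 2, 7, 9, 20, 100
agg_lstm = nn.LSTM(l * 8, hidden, bidirectional=True, batch_first=True)

mv_p = torch.randn(batch, M, l * 8)   # matching vectors of Q_1
mv_q = torch.randn(batch, N, l * 8)   # matching vectors of Q_2

_, (h_p, _) = agg_lstm(mv_p)          # h_p: (2, batch, hidden) -- fw and bw last states
_, (h_q, _) = agg_lstm(mv_q)

match_vec = torch.cat([h_p.permute(1, 0, 2).reshape(batch, -1),
                       h_q.permute(1, 0, 2).reshape(batch, -1)], dim=1)
print(match_vec.shape)                # torch.Size([2, 400]) = (batch, hidden * 4)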

5) Prediction layer

Fully connected layers produce the classification: in the implementation below, two linear layers with a tanh in between map the matching vector to the label logits.
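
A minimal sketch of the prediction layer matching the implementation below (two linear layers with a tanh in between, softmax for the class probabilities); all sizes are made up:

import torch
import torch.nn as nn
import torch.nn.functional as F

hidden, n_labels = 100, 2
fc1 = nn.Linear(hidden * 4, hidden * 2)
fc2 = nn.Linear(hidden * 2, n_labels)

match_vec = torch.randn(2, hidden * 4)        # stand-in for the aggregated vector
logits = fc2(torch.tanh(fc1(match_vec)))
print(F.softmax(logits, dim=-1))              # (batch, n_labels) class probabilities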

6) Summary

BiMPM and ESIM are similar in spirit: both have a matching layer and an aggregation layer, and both rely on BiLSTMs. The main difference is the matching scheme, where BiMPM exploits information from more perspectives. Because of the additional parameters, BiMPM is slower to train.

II. PyTorch Implementation

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence


class BiMPM(nn.Module):
    def __init__(self, vocab, word_dim, char_dim, hidden_size, w_size, n_labels, embs=None, char_hidden_size=None,
                 use_char_emb=True, max_word_len=None, char_vocab_size=None, dropout=0):
        super(BiMPM, self).__init__()
        self.dropout = dropout
        self.use_char_emb = use_char_emb
        self.hidden_size = hidden_size
        self.char_hidden_size = char_hidden_size
        self.max_word_len = max_word_len

        self.d = word_dim
        self.l = w_size

        # ----- Word Representation Layer -----
        if use_char_emb:
            if not char_hidden_size or not max_word_len or not char_vocab_size:
                raise ValueError("when use_char_emb is True, char_hidden_size, max_word_len and char_vocab_size must be provided")

            self.d = self.d + char_hidden_size
            self.char_emb = nn.Embedding(char_vocab_size, char_dim, padding_idx=0)
            self.char_LSTM = nn.LSTM(
                input_size=char_dim,
                hidden_size=char_hidden_size,
                num_layers=1,
                bidirectional=False,
                batch_first=True)

        self.word_emb = nn.Embedding(len(vocab), word_dim)
        # initialize word embedding with GloVe or Other pre-trained word embedding
        if embs:
            embvecs, embwords = embs
            self.word_emb.weight.data.copy_(torch.from_numpy(np.asarray(embvecs)))
        # no fine-tuning for word vectors
        self.word_emb.weight.requires_grad = False

        # ----- Context Representation Layer -----
        self.context_LSTM = nn.LSTM(
            input_size=self.d,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )

        # ----- Matching Layer -----
        for i in range(1, 9):
            setattr(self, f'mp_w{i}',
                    nn.Parameter(torch.rand(self.l, self.hidden_size)))

        # ----- Aggregation Layer -----
        self.aggregation_LSTM = nn.LSTM(
            input_size=self.l * 8,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )

        # ----- Prediction Layer -----
        self.pred_fc1 = nn.Linear(hidden_size * 4, hidden_size * 2)
        self.pred_fc2 = nn.Linear(hidden_size * 2, n_labels)

        self.reset_parameters()

    def reset_parameters(self):
        # ----- Word Representation Layer -----
        if self.use_char_emb:
            nn.init.uniform_(self.char_emb.weight, -0.005, 0.005)
            # zero vectors for padding
            self.char_emb.weight.data[0].fill_(0)

            nn.init.kaiming_normal_(self.char_LSTM.weight_ih_l0)
            nn.init.constant_(self.char_LSTM.bias_ih_l0, val=0)
            nn.init.orthogonal_(self.char_LSTM.weight_hh_l0)
            nn.init.constant_(self.char_LSTM.bias_hh_l0, val=0)

        # <unk> vectors is randomly initialized
        nn.init.uniform_(self.word_emb.weight.data[0], -0.1, 0.1)
        # ----- Context Representation Layer -----
        nn.init.kaiming_normal_(self.context_LSTM.weight_ih_l0)
        nn.init.constant_(self.context_LSTM.bias_ih_l0, val=0)
        nn.init.orthogonal_(self.context_LSTM.weight_hh_l0)
        nn.init.constant_(self.context_LSTM.bias_hh_l0, val=0)

        nn.init.kaiming_normal_(self.context_LSTM.weight_ih_l0_reverse)
        nn.init.constant_(self.context_LSTM.bias_ih_l0_reverse, val=0)
        nn.init.orthogonal_(self.context_LSTM.weight_hh_l0_reverse)
        nn.init.constant_(self.context_LSTM.bias_hh_l0_reverse, val=0)

        # ----- Matching Layer -----
        for i in range(1, 9):
            w = getattr(self, f'mp_w{i}')
            nn.init.kaiming_normal_(w)

        # ----- Aggregation Layer -----
        nn.init.kaiming_normal_(self.aggregation_LSTM.weight_ih_l0)
        nn.init.constant_(self.aggregation_LSTM.bias_ih_l0, val=0)
        nn.init.orthogonal_(self.aggregation_LSTM.weight_hh_l0)
        nn.init.constant_(self.aggregation_LSTM.bias_hh_l0, val=0)

        nn.init.kaiming_normal_(self.aggregation_LSTM.weight_ih_l0_reverse)
        nn.init.constant_(self.aggregation_LSTM.bias_ih_l0_reverse, val=0)
        nn.init.orthogonal_(self.aggregation_LSTM.weight_hh_l0_reverse)
        nn.init.constant_(self.aggregation_LSTM.bias_hh_l0_reverse, val=0)

        # ----- Prediction Layer ----
        nn.init.uniform_(self.pred_fc1.weight, -0.005, 0.005)
        nn.init.constant_(self.pred_fc1.bias, val=0)

        nn.init.uniform_(self.pred_fc2.weight, -0.005, 0.005)
        nn.init.constant_(self.pred_fc2.bias, val=0)

    def dropout_fun(self, v):
        return F.dropout(v, p=self.dropout, training=self.training)

    def forward(self, q1_inputs, q2_inputs, q1_char_inputs=None, q2_char_inputs=None, q1_lens=None, q2_lens=None):
        # ----- Matching Layer -----
        def mp_matching_func(v1, v2, w):
            """
            :param v1: (batch, seq_len, hidden_size)
            :param v2: (batch, seq_len, hidden_size) or (batch, hidden_size)
            :param w: (l, hidden_size)
            :return: (batch, seq_len, l)
            """
            seq_len = v1.size(1)

            # Trick for large memory requirement
            """
            if len(v2.size()) == 2:
                v2 = torch.stack([v2] * seq_len, dim=1)
            m = []
            for i in range(self.l):
                # v1: (batch, seq_len, hidden_size)
                # v2: (batch, seq_len, hidden_size)
                # w: (1, 1, hidden_size)
                # -> (batch, seq_len)
                m.append(F.cosine_similarity(w[i].view(1, 1, -1) * v1, w[i].view(1, 1, -1) * v2, dim=2))
            # list of (batch, seq_len) -> (batch, seq_len, l)
            m = torch.stack(m, dim=2)
            """

            # (1, 1, hidden_size, l)
            w = w.transpose(1, 0).unsqueeze(0).unsqueeze(0)
            # (batch, seq_len, hidden_size, l)
            v1 = w * torch.stack([v1] * self.l, dim=3)
            if len(v2.size()) == 3:
                v2 = w * torch.stack([v2] * self.l, dim=3)
            else:
                v2 = w * torch.stack([torch.stack([v2] * seq_len, dim=1)] * self.l, dim=3)

            m = F.cosine_similarity(v1, v2, dim=2)

            return m

        def mp_matching_func_pairwise(v1, v2, w):
            """
            :param v1: (batch, seq_len1, hidden_size)
            :param v2: (batch, seq_len2, hidden_size)
            :param w: (l, hidden_size)
            :return: (batch, seq_len1, seq_len2, l)
            """

            # Trick for large memory requirement
            """
            m = []
            for i in range(self.l):
                # (1, 1, hidden_size)
                w_i = w[i].view(1, 1, -1)
                # (batch, seq_len1, hidden_size), (batch, seq_len2, hidden_size)
                v1, v2 = w_i * v1, w_i * v2
                # (batch, seq_len, hidden_size->1)
                v1_norm = v1.norm(p=2, dim=2, keepdim=True)
                v2_norm = v2.norm(p=2, dim=2, keepdim=True)
                # (batch, seq_len1, seq_len2)
                n = torch.matmul(v1, v2.permute(0, 2, 1))
                d = v1_norm * v2_norm.permute(0, 2, 1)
                m.append(div_with_small_value(n, d))
            # list of (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, l)
            m = torch.stack(m, dim=3)
            """

            # (1, l, 1, hidden_size)
            w = w.unsqueeze(0).unsqueeze(2)
            # (batch, l, seq_len, hidden_size)
            v1, v2 = w * torch.stack([v1] * self.l, dim=1), w * torch.stack([v2] * self.l, dim=1)
            # (batch, l, seq_len, hidden_size->1)
            v1_norm = v1.norm(p=2, dim=3, keepdim=True)
            v2_norm = v2.norm(p=2, dim=3, keepdim=True)

            # (batch, l, seq_len1, seq_len2)
            n = torch.matmul(v1, v2.transpose(2, 3))
            d = v1_norm * v2_norm.transpose(2, 3)

            # (batch, seq_len1, seq_len2, l)
            m = div_with_small_value(n, d).permute(0, 2, 3, 1)

            return m

        def attention(v1, v2):
            """
            :param v1: (batch, seq_len1, hidden_size)
            :param v2: (batch, seq_len2, hidden_size)
            :return: (batch, seq_len1, seq_len2)
            """

            # (batch, seq_len1, 1)
            v1_norm = v1.norm(p=2, dim=2, keepdim=True)
            # (batch, 1, seq_len2)
            v2_norm = v2.norm(p=2, dim=2, keepdim=True).permute(0, 2, 1)

            # (batch, seq_len1, seq_len2)
            a = torch.bmm(v1, v2.permute(0, 2, 1))
            d = v1_norm * v2_norm

            return div_with_small_value(a, d)

        def div_with_small_value(n, d, eps=1e-8):
            # too small values are replaced by 1e-8 to prevent it from exploding.
            d = d * (d > eps).float() + eps * (d <= eps).float()
            return n / d

        # ----- Word Representation Layer -----
        # (batch, seq_len) -> (batch, seq_len, word_dim)
        q1_lens_tmp, q1_indices = torch.sort(q1_lens, descending=True)
        q1_input_tmp = q1_inputs[q1_indices]

        q2_lens_tmp, q2_indices = torch.sort(q2_lens, descending=True)
        q2_input_tmp = q2_inputs[q2_indices]


        p = self.word_emb(q1_input_tmp)
        h = self.word_emb(q2_input_tmp)

        if self.use_char_emb:
            # (batch, seq_len, max_word_len) -> (batch * seq_len, max_word_len)
            seq_len_p = q1_char_inputs.size(1)
            seq_len_h = q2_char_inputs.size(1)

            char_p = q1_char_inputs.view(-1, self.max_word_len)
            char_h = q2_char_inputs.view(-1, self.max_word_len)

            # (batch * seq_len, max_word_len, char_dim)-> (1, batch * seq_len, char_hidden_size)
            _, (char_p, _) = self.char_LSTM(self.char_emb(char_p))
            _, (char_h, _) = self.char_LSTM(self.char_emb(char_h))

            # (batch, seq_len, char_hidden_size)
            char_p = char_p.view(-1, seq_len_p, self.char_hidden_size)
            char_h = char_h.view(-1, seq_len_h, self.char_hidden_size)

            # (batch, seq_len, word_dim + char_hidden_size)
            p = torch.cat([p, char_p], dim=-1)
            h = torch.cat([h, char_h], dim=-1)

        p = self.dropout_fun(p)
        h = self.dropout_fun(h)

        # ----- Context Representation Layer -----
        # (batch, seq_len, hidden_size * 2)
        # lengths must be a CPU tensor for pack_padded_sequence in recent PyTorch
        p_packed = pack_padded_sequence(p, q1_lens_tmp.cpu(), batch_first=True)
        p_out, _ = self.context_LSTM(p_packed)
        con_p, _ = pad_packed_sequence(p_out, batch_first=True)

        h_packed = pack_padded_sequence(h, q2_lens_tmp.cpu(), batch_first=True)
        h_out, _ = self.context_LSTM(h_packed)
        con_h, _ = pad_packed_sequence(h_out, batch_first=True)

        _, q1_desorte_indices = torch.sort(q1_indices, descending=False)
        con_p = con_p[q1_desorte_indices]
        _, q2_desorte_indices = torch.sort(q2_indices, descending=False)
        con_h = con_h[q2_desorte_indices]

        con_p = self.dropout_fun(con_p)
        con_h = self.dropout_fun(con_h)

        # (batch, seq_len, hidden_size)
        con_p_fw, con_p_bw = torch.split(con_p, self.hidden_size, dim=-1)
        con_h_fw, con_h_bw = torch.split(con_h, self.hidden_size, dim=-1)

        # 1. Full-Matching

        # (batch, seq_len, hidden_size), (batch, hidden_size)
        # -> (batch, seq_len, l)
        mv_p_full_fw = mp_matching_func(con_p_fw, con_h_fw[:, -1, :], self.mp_w1)
        mv_p_full_bw = mp_matching_func(con_p_bw, con_h_bw[:, 0, :], self.mp_w2)
        mv_h_full_fw = mp_matching_func(con_h_fw, con_p_fw[:, -1, :], self.mp_w1)
        mv_h_full_bw = mp_matching_func(con_h_bw, con_p_bw[:, 0, :], self.mp_w2)

        # 2. Maxpooling-Matching

        # (batch, seq_len1, seq_len2, l)
        mv_max_fw = mp_matching_func_pairwise(con_p_fw, con_h_fw, self.mp_w3)
        mv_max_bw = mp_matching_func_pairwise(con_p_bw, con_h_bw, self.mp_w4)

        # (batch, seq_len, l)
        mv_p_max_fw, _ = mv_max_fw.max(dim=2)
        mv_p_max_bw, _ = mv_max_bw.max(dim=2)
        mv_h_max_fw, _ = mv_max_fw.max(dim=1)
        mv_h_max_bw, _ = mv_max_bw.max(dim=1)

        # 3. Attentive-Matching

        # (batch, seq_len1, seq_len2)
        att_fw = attention(con_p_fw, con_h_fw)
        att_bw = attention(con_p_bw, con_h_bw)

        # (batch, seq_len2, hidden_size) -> (batch, 1, seq_len2, hidden_size)
        # (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, 1)
        # -> (batch, seq_len1, seq_len2, hidden_size)
        att_h_fw = con_h_fw.unsqueeze(1) * att_fw.unsqueeze(3)
        att_h_bw = con_h_bw.unsqueeze(1) * att_bw.unsqueeze(3)
        # (batch, seq_len1, hidden_size) -> (batch, seq_len1, 1, hidden_size)
        # (batch, seq_len1, seq_len2) -> (batch, seq_len1, seq_len2, 1)
        # -> (batch, seq_len1, seq_len2, hidden_size)
        att_p_fw = con_p_fw.unsqueeze(2) * att_fw.unsqueeze(3)
        att_p_bw = con_p_bw.unsqueeze(2) * att_bw.unsqueeze(3)

        # (batch, seq_len1, hidden_size) / (batch, seq_len1, 1) -> (batch, seq_len1, hidden_size)
        att_mean_h_fw = div_with_small_value(att_h_fw.sum(dim=2), att_fw.sum(dim=2, keepdim=True))
        att_mean_h_bw = div_with_small_value(att_h_bw.sum(dim=2), att_bw.sum(dim=2, keepdim=True))

        # (batch, seq_len2, hidden_size) / (batch, seq_len2, 1) -> (batch, seq_len2, hidden_size)
        att_mean_p_fw = div_with_small_value(att_p_fw.sum(dim=1), att_fw.sum(dim=1, keepdim=True).permute(0, 2, 1))
        att_mean_p_bw = div_with_small_value(att_p_bw.sum(dim=1), att_bw.sum(dim=1, keepdim=True).permute(0, 2, 1))

        # (batch, seq_len, l)
        mv_p_att_mean_fw = mp_matching_func(con_p_fw, att_mean_h_fw, self.mp_w5)
        mv_p_att_mean_bw = mp_matching_func(con_p_bw, att_mean_h_bw, self.mp_w6)
        mv_h_att_mean_fw = mp_matching_func(con_h_fw, att_mean_p_fw, self.mp_w5)
        mv_h_att_mean_bw = mp_matching_func(con_h_bw, att_mean_p_bw, self.mp_w6)

        # 4. Max-Attentive-Matching

        # (batch, seq_len1, hidden_size)
        att_max_h_fw, _ = att_h_fw.max(dim=2)
        att_max_h_bw, _ = att_h_bw.max(dim=2)
        # (batch, seq_len2, hidden_size)
        att_max_p_fw, _ = att_p_fw.max(dim=1)
        att_max_p_bw, _ = att_p_bw.max(dim=1)

        # (batch, seq_len, l)
        mv_p_att_max_fw = mp_matching_func(con_p_fw, att_max_h_fw, self.mp_w7)
        mv_p_att_max_bw = mp_matching_func(con_p_bw, att_max_h_bw, self.mp_w8)
        mv_h_att_max_fw = mp_matching_func(con_h_fw, att_max_p_fw, self.mp_w7)
        mv_h_att_max_bw = mp_matching_func(con_h_bw, att_max_p_bw, self.mp_w8)

        # (batch, seq_len, l * 8)
        mv_p = torch.cat(
            [mv_p_full_fw, mv_p_max_fw, mv_p_att_mean_fw, mv_p_att_max_fw,
             mv_p_full_bw, mv_p_max_bw, mv_p_att_mean_bw, mv_p_att_max_bw], dim=2)
        mv_h = torch.cat(
            [mv_h_full_fw, mv_h_max_fw, mv_h_att_mean_fw, mv_h_att_max_fw,
             mv_h_full_bw, mv_h_max_bw, mv_h_att_mean_bw, mv_h_att_max_bw], dim=2)

        mv_p = self.dropout_fun(mv_p)
        mv_h = self.dropout_fun(mv_h)

        # ----- Aggregation Layer -----
        # (batch, seq_len, l * 8) -> (2, batch, hidden_size)
        _, (agg_p_last, _) = self.aggregation_LSTM(mv_p)
        _, (agg_h_last, _) = self.aggregation_LSTM(mv_h)

        # 2 * (2, batch, hidden_size) -> 2 * (batch, hidden_size * 2) -> (batch, hidden_size * 4)
        x = torch.cat(
            [agg_p_last.permute(1, 0, 2).contiguous().view(-1, self.hidden_size * 2),
             agg_h_last.permute(1, 0, 2).contiguous().view(-1, self.hidden_size * 2)], dim=1)
        x = self.dropout_fun(x)

        # ----- Prediction Layer -----
        x = torch.tanh(self.pred_fc1(x))  # F.tanh is deprecated in favor of torch.tanh
        x = self.dropout_fun(x)
        x = self.pred_fc2(x)

        return x
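
A hedged usage sketch for the class above (it relies on the imports at the top of the file). The vocabulary, sizes and inputs are made up, and character embeddings are disabled so only word ids and lengths are needed:

if __name__ == '__main__':
    vocab = {'<pad>': 0, '<unk>': 1, 'hello': 2, 'world': 3}
    model = BiMPM(vocab=vocab, word_dim=50, char_dim=20, hidden_size=100,
                  w_size=20, n_labels=2, use_char_emb=False)

    q1 = torch.randint(0, len(vocab), (4, 7))        # (batch, seq_len1)
    q2 = torch.randint(0, len(vocab), (4, 9))        # (batch, seq_len2)
    q1_lens = torch.tensor([7, 7, 6, 5])
    q2_lens = torch.tensor([9, 8, 8, 4])

    logits = model(q1, q2, q1_lens=q1_lens, q2_lens=q2_lens)
    print(logits.shape)                              # torch.Size([4, 2])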

 
