Comp-Agg (A Compare-Aggregate Model for Matching Text Sequences)

本文探讨了一种改进的比较-聚合框架,通过词嵌入和CNN聚合,针对文本向量进行高效匹配。实验通过多种数据集验证模型的泛化能力,并发现element-wise比较函数优于复杂网络。研究还揭示了门控单元和注意力机制在语义特征提取中的作用,以及向量差积在特征提取中的实用性。
摘要由CSDN通过智能技术生成

CompareAggregate研究意义:

1、采用“比较-聚合”框架,并对此进行改进

2、采用多种数据集验证模型的泛化性

 

本文主要结构如下所示:

一、Abstract

      摘要部分主要介绍本文利用词嵌入作为输入,CNN网络作为聚合函数,提出比较聚合框架;关注于不同的比较函数来对文本向量进行匹配;并且使用不同的几份数据评估模型;

基于element-wise的比较函数可能会比复杂神经网络效果更好。

二、Introudction

      首先提及了很多自然语言处理任务都需要对两个或多个句子进行匹配,然后作出决定。

       

三、Method

             主要介绍模型的结构以及六个不同的比较函数

四、Experiment

            实验部分主要介绍不同比较函数以及组合函数在四个不同任务数据集合的效果,证明组合比较函数模型的有效性

五、Related Work

           相关工作部分简单的描述了孪生网络、注意力机制以及比较-聚合网络相关的应用

六、Conclusions

            最后一部分总结了本文系统分析“比较-聚合”模型在四个不同任务数据集上的有效性,此外还提出了词级别的比较函数element-wise 比较函数表现好于其它函数,并且根据实验结果很多不同任务可以共享“比较-聚合”结构,在未来的任务中,可以把它使用在多任务学习中。

         关键点: 采用“比较-聚合”框架;利用多种数据集证明模型的有效性;提出多种比较函数并探究了交互的最佳方式

         创新点: 利用门控单元提取语义特征,利用注意力机制完成句子权重匹配,利用向量的差和积进行特征提取

七、Code

# -*- coding: utf-8 -*-

# @Time : 2021/2/14 下午2:07
# @Author : TaoWang
# @Description : "比较-聚合" 模型结构


import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torch.autograd import Variable


# 预处理层
class Preprocess(nn.Module):
    def __init__(self, in_features, out_features):
        """
        :param in_features: 
        :param out_features: 
        """
        super().__init__()
        self.Wi = nn.Parameter(torch.randn(in_features, out_features))
        self.bi = nn.Parameter(torch.randn(out_features))
        
        self.wu = nn.Parameter(torch.randn(in_features, out_features))
        self.bu = nn.Parameter(torch.randn(out_features))
        
    def forward(self, x):
        """
        :param x: 
        :return: 
        """
        gate = torch.matmul(x, self.Wi)
        gate = torch.sigmoid(gate + self.bi.expand_as(date))
        
        out = torch.matmul(x, self.Wu)
        out = torch.tanh(out + self.bu.expand_as(out))
        
        return gate * out
    
    
# 注意力层
class Attention(nn.Module):
    def __init__(self):
        super().__init__()
        self.wg = nn.Parameter(torch.randn(hidden_size, hidden_size))
        self.bg = nn.Parameter(torch.randn(hidden_size))
        
    def forward(self, q, a):
        """
        :param q: 
        :param a: 
        :return: 
        """
        G = torch.matmul(q, self.wg)
        G = G + self.bg.expand_as(G)
        G = torch.matmul(G, a.permute(0, 2, 1))
        G = torch.softmax(G, dim=1)
        H = torch.matmul(G.permute(0, 2, 1), q)
        
        return H
    

# 模型比较层
class Compare(nn.Module):
    def __init__(self):
        super().__init__()
        self.W = nn.Parameter(torch.randn(2*hidden_size, hidden_size))
        self.b = nn.Parameter(torch.randn(hidden_size))
        
    def forward(self, h, a):
        """
        :param h: 
        :param a: 
        :return: 
        """
        sub = (h - a) * (h - a)
        mult = h * a
        T = torch.matmul(torch.cat([sub, mult], dim=2), self.W)
        T = torch.relu(T + self.b.expand_as(T))
        
        return T
    

# 模型比较聚合层汇总

class CompAgg(torch.nn.Module):
    def __init__(self):
        super(CompAgg, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_size)
        self.embedding.weight.data.copy_(torch.from_numpy(embed))
        self.preprocess = Preprocess(embedding_size, hidden_size)
        self.attention = Attention()
        self.compare = Compare()
        self.aggregate = nn.Conv1d(in_channels=max_len, out_channels=window, kernel_size=3, stride=1, padding=1)
        self.predict = nn.Linear(window * hidden_size, classes)
        
    def forward(self, q, a):
        """
        :param q: 设 q长度 20
        :param a: 设 a长度 40
        :return: 
        """
        # emb_q: batch * 20 * 200, emb_a: batch * 40 * 200
        emb_q, emb_a = self.embedding(q), self.embedding(a)
        # q_bar: batch * 20 * 100, a_bar: batch * 40 * 100
        q_bar, a_bar = self.preprocess(emb_q), self.preprocess(emb_a)
        # H: batch * 40 * 100
        H = self.attention(q_bar, a_bar)
        # T: batch * 40 * 100
        T = self.compare(H, a_bar)
        # r: batch * 3 * 100
        r = self.aggregate(T)
        # r: batch * 300
        r = r.view(-1, window * hidden_size)
        # out: batch * 3
        out = self.predict(r)
        
        return out

 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>