百度语义匹配模型-simnet代码整理

关于simnet提出背景以及可以应用的地方,机器之心这篇文章里已经说得非常详细了。https://www.jiqizhixin.com/articles/2017-06-15-5

本文主要是记录一下自己使用simnet做语义匹配任务的流程,并对代码各个模块的功能进行整理和解释。

一、模型结构

                                                      

SimNet 框架如上图所示,主要分为输入层、表示层和匹配层。

各个层的功能:

1.输入层

该层通过 look up table 将文本词序列转换为 word embedding 序列。

2.表示层

该层主要功能是由词到句的表示构建,或者说将序列的孤立的词语的 embedding 表示,转换为具有全局信息的一个或多个低维稠密的语义向量。最简单的是 Bag of Words(BOW)的累加方法,除此之外,我们还在 SimNet 框架下研发了对应的序列卷积网络(CNN)、循环神经网络(RNN)等多种表示技术。当然,在得到句子的表示向量后,也可以继续累加更多层全连接网络,进一步提升表示效果。

3.匹配层

该层利用文本的表示向量进行交互计算,根据应用的场景不同,有两种匹配算法。

Representation-based Match和Interaction-based Match。

而在Representation-based Match有两种计算方式:

                                      

且通常选用的都是Representation-based Match的方法。

                                      

若采用pair-wise Ranking Loss 来进行 SimNet 的训练。以网页搜索任务为例,假设搜索查询文本为 Q,相关的一篇文档为 D+,不相关的一篇文档为 D-,二者经过 SimNet 网络得到的和 Q 的匹配度得分分别为 S(Q,D+) 和 S(Q,D-),而训练的优化目标就是使得 S(Q,D+)>S(Q,D-)。

实际中,我们一般采用 Max-Margin 的 Hinge Loss:

                                                                             max{0,margin-(S(Q,D+)-S(Q,D-))}

二、运用

使用这个开源代码来完成语义匹配任务通常需要以下几个步骤:

1. 首先我们需要将需要计算匹配度的句对进行转换,变成tfrecord的格式。

2.搭建网络。

3.读取数据进行训练

4.进行测试。

simnet中提供了多种网络供选择,并且也有不同的loss可以选择进行优化,这里我们只选用pointwise格式的数据,用MLPCnn网络,SoftmaxWithLoss的损失函数来构建我们特定的模型。

MPLCNN网络的结构:

                                        

输入-输入进行embedding-embedding结果进CNN-CNN结果经过relu-relu出来后对左右进行concat-concat后接全连接层(

如果是pointwise,现将左右结果进行concat,然后通过fc1,然后通过relu,再通过fc2,输出pred(全连接层实际上是X*W+b的一个计算。

如果是pairwise,则relu出来后不需要讲左右进行concat,直接将relu出来的左右结果经过fc1,fc1的输出结果经过cosine Layer然后输出预测值)

代码如下:

# coding:utf-8
from collections import Counter
import logging
import numpy
import time
import sys
import os
import json
import tensorflow as tf
import traceback
import math

fwords = "data/word2id.json"

forigin_train_corpus = "data/train.sample"
forigin_test_corpus = "data/test.sample"

ftrain_pointwise = "data/kesci_train_pointwise.txt"
ftest_pointwise = "data/kesci_test_pointwise.txt"

ftrain_pointwise_data = "kesci_train_pointwise.data"
ftest_pointwise_data = "kesci_test_pointwise.data"

############################################################
# 0: 转换数据格式
def data2pointwise(fin, fout, fwords):
    word2id = json.load(open(fwords, "r"))
    with open(fin, "r") as fr, open(fout, "w") as fw:
        for line in fr:
            line = line.strip().split(",")
            q = [word2id[w] for w in line[1].split() if w in word2id]
            d = [word2id[w] for w in line[3].split() if w in word2id]

            if len(q) < 5 or len(d) < 5:
                continue

            q = list(map(str, q))
            d = list(map(str, d))
            fw.write(" ".join(q) + "\t" + " ".join(d) + "\t" + line[4] &
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值