关于simnet提出背景以及可以应用的地方,机器之心这篇文章里已经说得非常详细了。https://www.jiqizhixin.com/articles/2017-06-15-5
本文主要是记录一下自己使用simnet做语义匹配任务的流程,并对代码各个模块的功能进行整理和解释。
一、模型结构
SimNet 框架如上图所示,主要分为输入层、表示层和匹配层。
各个层的功能:
1.输入层
该层通过 look up table 将文本词序列转换为 word embedding 序列。
2.表示层
该层主要功能是由词到句的表示构建,或者说将序列的孤立的词语的 embedding 表示,转换为具有全局信息的一个或多个低维稠密的语义向量。最简单的是 Bag of Words(BOW)的累加方法,除此之外,我们还在 SimNet 框架下研发了对应的序列卷积网络(CNN)、循环神经网络(RNN)等多种表示技术。当然,在得到句子的表示向量后,也可以继续累加更多层全连接网络,进一步提升表示效果。
3.匹配层
该层利用文本的表示向量进行交互计算,根据应用的场景不同,有两种匹配算法。
Representation-based Match和Interaction-based Match。
而在Representation-based Match有两种计算方式:
且通常选用的都是Representation-based Match的方法。
若采用pair-wise Ranking Loss 来进行 SimNet 的训练。以网页搜索任务为例,假设搜索查询文本为 Q,相关的一篇文档为 D+,不相关的一篇文档为 D-,二者经过 SimNet 网络得到的和 Q 的匹配度得分分别为 S(Q,D+) 和 S(Q,D-),而训练的优化目标就是使得 S(Q,D+)>S(Q,D-)。
实际中,我们一般采用 Max-Margin 的 Hinge Loss:
max{0,margin-(S(Q,D+)-S(Q,D-))}
二、运用
使用这个开源代码来完成语义匹配任务通常需要以下几个步骤:
1. 首先我们需要将需要计算匹配度的句对进行转换,变成tfrecord的格式。
2.搭建网络。
3.读取数据进行训练
4.进行测试。
simnet中提供了多种网络供选择,并且也有不同的loss可以选择进行优化,这里我们只选用pointwise格式的数据,用MLPCnn网络,SoftmaxWithLoss的损失函数来构建我们特定的模型。
MPLCNN网络的结构:
输入-输入进行embedding-embedding结果进CNN-CNN结果经过relu-relu出来后对左右进行concat-concat后接全连接层(
如果是pointwise,现将左右结果进行concat,然后通过fc1,然后通过relu,再通过fc2,输出pred(全连接层实际上是X*W+b的一个计算。
如果是pairwise,则relu出来后不需要讲左右进行concat,直接将relu出来的左右结果经过fc1,fc1的输出结果经过cosine Layer然后输出预测值)
代码如下:
# coding:utf-8
from collections import Counter
import logging
import numpy
import time
import sys
import os
import json
import tensorflow as tf
import traceback
import math
fwords = "data/word2id.json"
forigin_train_corpus = "data/train.sample"
forigin_test_corpus = "data/test.sample"
ftrain_pointwise = "data/kesci_train_pointwise.txt"
ftest_pointwise = "data/kesci_test_pointwise.txt"
ftrain_pointwise_data = "kesci_train_pointwise.data"
ftest_pointwise_data = "kesci_test_pointwise.data"
############################################################
# 0: 转换数据格式
def data2pointwise(fin, fout, fwords):
word2id = json.load(open(fwords, "r"))
with open(fin, "r") as fr, open(fout, "w") as fw:
for line in fr:
line = line.strip().split(",")
q = [word2id[w] for w in line[1].split() if w in word2id]
d = [word2id[w] for w in line[3].split() if w in word2id]
if len(q) < 5 or len(d) < 5:
continue
q = list(map(str, q))
d = list(map(str, d))
fw.write(" ".join(q) + "\t" + " ".join(d) + "\t" + line[4] &