基于CNN和序列标注的对联机器人 | 附数据集 & 开源代码

640


作者丨苏剑林

单位丨广州火焰信息科技有限公司

研究方向丨NLP,神经网络

个人主页丨kexue.fm


缘起


前几天看到了这个脑洞清奇的对联AI,大家都玩疯了一文,觉得挺有意思,难得的是作者还整理并公开了数据集,所以决定自己尝试一下。


动手


“对对联”,我们可以看成是一个句子生成任务,可以用 Seq2Seq 完成,跟我之前写的玩转Keras之Seq2Seq自动生成标题一样,稍微修改一下输入即可。上面提到的文章所用的方法也是 Seq2Seq,可见这算是标准做法了。


分析


然而,我们再细想一下就会发现,相对于一般的句子生成任务,“对对联”有规律得多:1)上联和下联的字数一样;2)上联和下联的每一个字几乎都有对应关系。


如此一来,其实对对联可以直接看成一个序列标注任务,跟分词、命名实体识别等一样的做法即可。这便是本文的出发点。 


说到这,其实本文就没有什么技术含量了,序列标注已经是再普通不过的任务了,远比一般的 Seq2Seq 来得简单。


所谓序列标注,就是指输入一个向量序列,然后输出另外一个通常长度的向量序列,最后对这个序列的“每一帧”进行分类。相关概念来可以在简明条件随机场CRF介绍 | 附带纯Keras实现一文进一步了解。


模型


本文直接边写代码边介绍模型。如果需要进一步了解背后的基础知识的读者,还可以参考《中文分词系列:基于双向LSTM的Seq2Seq字标注》[1]《中文分词系列:基于全卷积网络的中文分词》[2]《基于CNN和VAE的作诗机器人:随机成诗》[3]


我们所用的模型代码如下:


x_in = Input(shape=(None,))
x = x_in
x = Embedding(len(chars)+1, char_size)(x)
x = Dropout(0.25)(x)

x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)
x = gated_resnet(x)

x = Dense(len(chars)+1, activation='softmax')(x)

model = Model(x_in, x)
model.compile(loss='sparse_categorical_crossentropy',
              optimizer='adam')


其中 gated_resnet 是我定义的门卷积模块:


def gated_resnet(x, ksize=3):
    # 门卷积 + 残差
    x_dim = K.int_shape(x)[-1]
    xo = Conv1D(x_dim*2, ksize, padding='same')(x)
    return Lambda(lambda x: x[0] * K.sigmoid(x[1][..., :x_dim]) \
                            + x[1][..., x_dim:] * K.sigmoid(-x[1][..., :x_dim]))([x, xo])


仅此而已,就这样完了,剩下的都是数据预处理的事情了。当然,读者也可以尝试也可以把 gated_resnet 换成普通的层叠双向 LSTM,但我实验中发现层叠双向 LSTM 并没有层叠 gated_resnet 效果好,而且 LSTM 相对来说也很慢。


效果


训练的数据集来自以下链接,感谢作者的整理。


https://github.com/wb14123/couplet-dataset


完整代码:


https://github.com/bojone/seq2seq/blob/master/couplet_by_seq_tagging.py


训练过程:


640?wx_fmt=png

 对联机器人训练过程


部分效果:


640?wx_fmt=png


看起来还是有点味道的。注意“晚风摇树树还挺”是训练集的上联,标准下联是“晨露润花花更红”,而模型给出来的是“夜雨敲花花更香”,说明模型并不是单纯地记住训练集的,还是有一定的理解能力;甚至我觉得模型对出来的下联更生动一些。


总的来说,基本的字的对应似乎都能做到,就缺乏一个整体感。总体效果没有下面两个好,但作为一个小玩具,应该能让人满意了。


王斌版AI对联:https://ai.binwang.me/couplet/

微软对联:https://duilian.msra.cn/default.htm


结语


最后,也没有什么好总结的。我就是觉得这个对对联应该算是一个序列标注任务,所以就想着用一个序列标注的模型来试试看,结果感觉还行。


当然,要做得更好,需要在模型上做些调整,还可以考虑引入 Attention 等,然后解码的时候,还需要引入更多的先验知识,保证结果符合我们对对联的要求。这些就留给有兴趣做下去的读者继续了。


相关链接


[1] https://kexue.fm/archives/3924

[2] https://kexue.fm/archives/4195

[3] https://kexue.fm/archives/5332


640?



点击以下标题查看作者其他文章: 





640?#投 稿 通 道#

 让你的论文被更多人看到 



如何才能让更多的优质内容以更短路径到达读者群体,缩短读者寻找优质内容的成本呢? 答案就是:你不认识的人。


总有一些你不认识的人,知道你想知道的东西。PaperWeekly 或许可以成为一座桥梁,促使不同背景、不同方向的学者和学术灵感相互碰撞,迸发出更多的可能性。 


PaperWeekly 鼓励高校实验室或个人,在我们的平台上分享各类优质内容,可以是最新论文解读,也可以是学习心得技术干货。我们的目的只有一个,让知识真正流动起来。


来稿标准:

• 稿件确系个人原创作品,来稿需注明作者个人信息(姓名+学校/工作单位+学历/职位+研究方向) 

• 如果文章并非首发,请在投稿时提醒并附上所有已发布链接 

• PaperWeekly 默认每篇文章都是首发,均会添加“原创”标志


? 投稿邮箱:

• 投稿邮箱:hr@paperweekly.site 

• 所有文章配图,请单独在附件中发送 

• 请留下即时联系方式(微信或手机),以便我们在编辑发布时和作者沟通




?


现在,在「知乎」也能找到我们了

进入知乎首页搜索「PaperWeekly」

点击「关注」订阅我们的专栏吧



关于PaperWeekly


PaperWeekly 是一个推荐、解读、讨论、报道人工智能前沿论文成果的学术平台。如果你研究或从事 AI 领域,欢迎在公众号后台点击「交流群」,小助手将把你带入 PaperWeekly 的交流群里。


640?

▽ 点击 | 阅读原文 | 查看作者博客

PyTorch是一种流行的深度学习框架,可以用于构建卷积神经网络(CNN)等模型。在猫狗分类任务中,我们可以使用PyTorch来训练一个CNN模型来对猫和狗的图像进行分类。 首先,我们需要准备一个猫狗分类的数据集。可以在网上找到已经标注好的猫狗图像数据集,例如Kaggle上的猫狗大战数据集。这个数据集包含了数千张猫和狗的图像,以及它们对应的标签。 接下来,我们需要导入必要的PyTorch库和模块,例如torch、torchvision等。 然后,我们需要定义一个CNN模型。可以使用PyTorch提供的nn模块来搭建一个简单的CNN网络,包括卷积层、池化层和全连接层等。可以根据具体任务的需求和网络结构进行调整。 在搭建好网络之后,我们需要定义损失函数和优化器。对于猫狗分类任务,可以使用交叉熵损失函数来衡量预测结果和真实标签的差异,并选择适当的优化器,如SGD、Adam等来更新模型的参数。 接下来,我们可以始训练模型。将数据集分为训练集和测试集,使用训练集来迭代地更新模型参数,计算损失函数并通过反向传播算法更新模型。在每个epoch结束后,使用测试集来评估模型的性能,如准确率、精确率、召回率等。 最后,我们可以使用训练好的模型对新的猫狗图像进行分类预测。将图像传入模型中,得到对应的预测结果,即猫或狗的标签。 总结来说,PyTorch可以用于搭建CNN模型进行猫狗分类任务。需要准备好猫狗分类的数据集,在训练过程中使用损失函数和优化器来更新模型参数,并使用测试集来评估模型性能。最终可以使用训练好的模型对新的猫狗图像进行分类预测。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值