手把手教你实现命名实体识别

1.互联网金融新实体发现(数据获取)

比赛链接: 互联网金融新实体发现.
本博客主要参考他的文章和代码: 阿力阿哩哩.感兴趣的话可以关注他的知乎、公众号以及B站账号。

2.环境搭建

(1)硬件环境:
操作系统:windows 10或者 linux(Ubuntu 16~18) (本人使用的windows 10)
硬件配置:主要是显卡要求:1660Ti 6G(最起码要保证有一个显卡)
(2)软件环境:
这里最好自己创建一个虚拟环境,然后在里面配置一下各种库的版本。

  • 创建虚拟环境:conda create –n ccf_ner python==3.6
  • 进入虚拟环境:conda activate ccf_ner
  • tensorflow-gpu==1.10
  • cudatoolkit==9.0
  • cudnn=7.0
  • tqdm = 4.60.0
  • pandas==0.25.3
  • numpy==1.14.5
    在这里插入图片描述

3.赛题分析

(1)实验流程
在这里插入图片描述

(2)代码结构

  • preprocess/preprocess.py: 对原始数据进行预处理,包含数据清洗、数据切割、数据格式转换

  • model.py:单模构建代码,包含BERT+BILISTM-CRF、BERT+IDCNN-CRF

  • utils.py:DataIterator数据迭代器, 用于生成batch数据喂入模型

  • train_fine_tune.py:模型的训练(即模型参数微调)

  • predict.py:模型预测(微调好的模型用于测试集的预测)

  • ensemble/ensemble.py:对predict.py模型生成的文件进行复原,生成单模的文字预测结果

  • bert/tokenization.py:BERT源码分词工具

  • tf_utils文件夹:BERT源码的修改以及CRF等的开源代码

  • config.py:超参数设置和预训练模型等的路径设置

  • optimization.py:参数优化器

  • post_process文件夹:

  • [1] get_ensemble_final_result.py: 对ensemble.py生成的单模文字结果进行拼接,因为在预处理的时候将测试集切成了多份。

  • [2] post_ensemble_final_result.py: 对get_ensemble_final_result.py 生成的文件进行后处理,得到最终的单模文字结果

4.代码解析

4.1数据预处理

(1)首先查看一下原始数据:
在这里插入图片描述
样本数据包含文本标识号(ID)、题目(title)、正文(text)、识别出的未知实体(unknownEntity)四列,这个阶段需要做的事情就是要将数据样本处理成BIEO或BIO等格式。预处理之后的数据格式如下:
在这里插入图片描述
(2)代码解析

  • 导入相关库
import pandas as pd
import codecs
import re
import json
import sys
train_df = pd.read_csv('Round2_train.csv', encoding='utf-8')
test_df = pd.read_csv('Round2_Test.csv', encoding='utf-8')
len_treshold =512 - 2  #  每条数据的最大长度, 留下两个位置给[CLS]和[SEP]
  • 找出所有的非数字、非中文、分英文的符号
additional_chars = set()
for t in list(test_df.text) + list(train_df.text):
    additional_chars.update(re.findall(u'[^\u4e00-\u9fa5a-zA-Z0-9\*]', str(t)))
    # u4e00-\u9fa5a 中文字符的十六进制 、a-zA-Z0-9 大小写字母以及数字 ^表示取反
extra_chars = set("!#$%&\()*+,-./:;<=>?@[\\]^_`{|}~!#¥%&?《》{}“”,:‘’。()·、;【】")#需要保留的字符
additional_chars = additional_chars.difference(extra_chars)#滤除除了extra_chars的字符
  • 定义stop_word
def stop_words(x):
    try:
        x = x.strip() #用于移除字符串头尾指定的字符
    except:
        return ''
    x = re.sub('{IMG:.?.?.?}', '', x)
    x = re.sub('<!--IMG_\d+-->', '', x)
    x = re.sub('(https?|ftp|file)://[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]', '', x)  # 过滤网址
    x = re.sub('<a[^>]*>', '', x).replace("</a>", "")  # 过滤a标签
    x = re.sub('<P[^>]*>', '', x).replace("</P>", "")  # 过滤P标签
    x = re.sub('<strong[^>]*>', ',', x).replace("</strong>", "")  # 过滤strong标签
    x = re.sub('<br>', ',', x)  # 过滤br标签
    x = re.sub('www.[-A-Za-z0-9+&@#/%?=~_|!:,.;]+[-A-Za-z0-9+&@#/%=~_|]', '', x).replace("()", "")  # 过滤www开头的网址
    x = re.sub('\s', '', x)   # 过滤不可见字符
    x = re.sub('Ⅴ', 'V', x)
    
    
    for wbad in additional_chars:
        x = x.replace(wbad,'')
    return x

-将数据中标题和文本合并,数据中的实体主要在title和text上

train_df['text'] =  train_df['title'].fillna('') + train_df['text'].fillna('')
test_df['text'] =  test_df['title'].fillna('') + test_df['text'].fillna('')
#合并
  • 使用停止词滤除数据中的噪声并对缺失值进行处理
train_df['text'] = train_df['text'].apply(stop_words)
test_df['text'] = test_df['text'].apply(stop_words)
train_df = train_df.fillna('')  #对缺失值进行处理
  • 找出错误标签
label_list = train_df['unknownEntities'].tolist()
text_list =  train_df['text'].tolist()  #由于之前的预处理已经将title和text合并
id_list =  train_df['id'].tolist()
false_get_id = []
false_get_label = []
for i,label in enumerate(label_list):
    text = text_list[i]
    idx = id_list[i]
    l_l = label.split(';')
    not_in = []
    for li in l_l:
        if li not in text:  #如果实体标签不在文本中
            not_in.append(i)
    if len(not_in) > 0:
        false_get_id.append(idx)
        false_get_label.append(label)
        
  • 修复错误标签
repair_id_label = ['大象健康科技有限公司;健康猫', '人人爱家金融', '速借贷;有信钱包', '速借贷;有信钱包', '速借贷;有信钱包',
                   '速借贷;有信钱包', '软银;(必发)BETFAIR;火币', 'ATC国际期货;香港恒利金业;嘉信金服', 'Finci芬吉', '闪电借款;掌众财富',
                   '陀螺', '', '', '大象健康科技有限公司;健康猫', '宝点网;e租宝',  '大象健康科技有限公司;健康猫', '大象健康科技有限公司',
                   '', 'Exp金融资产;expasset项目;Plustoken钱包;expasset;Exp亚斯特', '钱-宝-网', 'Vpay支付', '', '盈康易和元消痛贴',
                   'Bitfund', 'GEC币', '时代证卷', '火星钱包;火星数字资产银行', '中汇国际期货;中宏资管;先汇国际', 'ST国际(搜宝国际)',
                   '和信售货机限公司', '洛安期货;昶胜国际;BKB数字货币;中恒策略;MBGMarkets', '', '', '', '沪深理财;首投理财',
                   'GEC币', '', '安盛', '', '智慧晶商城', '正宇控股', 'ShapeShift;Coinbase;CactusCustody;GenesisCapital;贝宝金融;比特大陆;Matrix;Matrixport;BitGoInc.;MatrixportBeta',
                   'wotoken', 'wotoken;WoToken', 'wotoken;Wotoken', 'wotoken;WoToken', 'wotoken;WoToken', '小微E贷通', 'ECOIN', 'GEC环保币',
                   'PTFX普顿外汇;PTFX;PTFX(普顿);聚宝金融;马胜金融;3M;IGOFX;金殿环球;OCTO;澳沃国际;期海聚金;期海财经', '一加二金融;富车在线;万贝贷;牛娃互联网金融;古德金融超市',
                   'JOJOMARKETS', '团贷网;钱香金融;金沙江;点亮;磐庆;玖臻资本;PPmoney;爱钱进;熊猫金库;饭饭金服;惠人贷;拍拍贷;麻袋财富', 'EOS',
                   'ATC艾特币;ZFB-致富链;趣步;巴特币;TBC天宝币;艾特币;智慧晶',
                   '有融网;银谷财富;爱福家;银豆网;长沙赛鼎生物科技有限公司;金瑞龙理财;湖南振湘网络科技有限公司;伊思多尔', '乐易Bank', '快刷;支付通;金中宝',
                   '应价零批;Lnko;plustoken;金网安泰;信雅达;PlusToken;lnko',  '普信金服APP;宏利保险;北京征和惠通基金管理有限公司;征和惠通;信中利;征和控股集团', '普信金服APP',
                   '普信金服APP', '普信金服APP', '东霖国际;波场超级社区;Plustoken;PlusToken;GCG钜富;PTFX普顿;RCFX;GCG钜富金融;GCGASIA鉅富金融;GCGASIA;香港钜富金融公司;外汇米治;SOXEX交易所;闪链SHE;火币HT',
                   '北斗股权;北斗生态圈;北斗期权;易购商城;趣步;循环信用卡', 'MGCTOKEN;mgctoken;MGCtoken',  '安徽天策;Plustoken;Lnko钱包;SOXEX交易所;波场超级社区;PlusToken;米治;PTFX',
                   '淘优乐;链豆', '', '雅布力', '友信证券;友信智投', '', '华盈城市集团;厦门中海航集团;恒优国际集团', '图灵资本;道生资本;小智投资;银杏谷资本', '',
'金蛋理财;网利宝;金百万;大麦理财;伽满优;玖富;米族金融;合众e贷;拼客顺风车;麦麦理财;盘龙财富;翼龙贷;证券单;微贷网;指旺财富;银谷在线;华侨宝;投哪网;团贷网;陆金所;积木盒子;悟空理财;广信贷;向上金服;银多网;花生米富;用钱宝;乾包网;盈盈理财;海鹭金服;智新宝;京贷金服;点牛金融;融信网;易享贷;51人品;金融工厂',
                   'UBank', '永利金控控股有限公司;UBS香港', '金机宝', '3m3.0互助;mmm3.0', '红马优购;车智汇狗狗币;酒链世界;霍特币;趣步;dogx钱包;以太坊;闪电鸡',
                   'ALicoin;TheDAO项目', '四川创梦森林软件科技有限公司;Cmsl', 'FCoin;火币', 'DEtoken', 'dogx;dogxwallet', 'DragonEx龙网;IOST',  'EXPASSET;EXP', '利润钱包;EXP-ASSET', 'FCoin',
                   'Finci芬吉;嘉实资本;TF环球金融;众元国际;捷盛国际;金源财富;GMO奥诺;鼎盈信投;中恒策略;億運富國際;速汇国际;辉立国际;世纪金业;艾斯国际;金山金融;MALAFY;宏源国际;鸿安资本;DGINT;问鼎财富;鸿昇国际;鸿翔国际;圆汇金融;唐印国际;恒牛策略;东财国际;创昇国际;锦丰国际',
                   'GEC', 'GEC', 'GEC', '瑞波(XRP)', 'HES和氏币', 'iBank数字钱包', 'IOToken(米奥钱包);IOToken', 'JMC', 'Jojomarkets', 'KB数字资产兑换平台', 'LongBit',
                   'LXFX;丰盈服务;立盟策略;易配配资;捷盈资本;众昇策略;boss金服', 'DLC数字货币;MCC数字矿工币', 'MChain;EXX交易所;蜂鸟财经;道轩资本', 'Mchain;MChain', 'PlusToken',
                   'ProsperToken;LTC莱特币;EOS柚子;狗狗币;BCH比特现金;XRP瑞波币;osp;达世币;DOGE狗狗币;火币;plustoken;ETC以太经典;DASH达世币', 'SKY;sky云世纪', '华润商业;中星集团;金丰投资;上海古北集团',
                   '速通宝Vpay;速通宝;瑞波币', 'WAD国际拆分理财平台','WAD国际金融平台', 'WBEX', 'WPAY', 'YouBank', 'jojomarkets', '惠恩商城;MGC钱包;东霖国际', '百信通国际;CITIGLOBAL花旗国际;海.贝国际;海慧通;众元国际;恒牛策略;世纪金业;诺安期货;金源财富;白象国际;辉立国际;HATSKY;博迈在线;bitkoc德林资本;艾斯国际;JTL国际;长虹资本;HDI国际;嘉晟财富;SpeedTrad速汇国际;mrt捷盈资本;万鼎国际;信融期权;恒利金业;britrading;新源财经;东吴金服;创昇国际;CXM希盟;宏源国际;旭升策略;富通国际;海利国际;合欣国际;东财国际;九梦财富;中赢国际EGML环球金融;国人策略;优信外汇;汇丰联合;鸿运信投;鼎盈信投;信邦策略;宏观策略;聚富策略;汇融国际',
                    '北京丽泽金融商务区控股有限公司;北京金唐天润置业发展集团;北京戴德梁行物业管理有限公司', 'Drivechain;侧链(Sidechain)', '58coin', 'OK链', '', 'Vpay;VPAY支付', '',
                   '菠菜理财;上海犇序实业有限公司;象象财富;小灰熊金服;云端金融;掌悦理财;央金所;巨人理财;利利金服', 'Guardian', '钜派投资集团;布尔金融', '环球金融;华宝基金', '成都奇米达',
                   '成都潜隆贷;麻袋理财;君融贷;沃时代;金银猫;恒信易贷;联连理财', '成都潜隆贷;宝通网;汉金所;鑫鹏贷;中旭鸿基投资;丰鼎金融;联鼎投资;昊祥投资;国有投资展恒理财',
                   '捷信(中国)金融公司;捷信', '侠侣联盟厦门侠网旅游服务有限公司;厦门侠网旅游服务有限公司;旅划算广州市旅划算国际旅行社有限公司',  '承返网;承返网(承返(广州)网络科技有限公司)',
                   '合拍贷;春天金融;稳银在线;可易金融;起点贷;深圳盈泰联合投资管理有限公司;深圳智鼎投资管理有限公司;深圳市中云金融服务有限公司;小宝金服;盈泰联合;印子坊;浙江颐荣资产管理有限公',
                   '维卡币', '优客工场;人人车;唱吧;考拉先生;蔚来汽车;雷蛇;柔宇科技;汉富资本', '', '金证股份;欧普康视', 'ORANFLAG', '旭隆金业;旺润配资;金银策略;汇丰鸿利;云旗金服;策略资本;翻翻配资;中首上上策;震泰国际;亨达国际;稳赢策略;四川大宗;海牛金服;指南针大宗;盈策略;牛360配资;一股期权;维海配资;银华中期;博时国际期货;中泰期权;创利融;步步盈配资;中证期权;航心配资;鼎鑫金业;创盈金服;九五配资;鼎牛配资;亿配资;华瑞国际;鼎盛配资;艾德配资;百汇期权;点点金富通国际',
                   'sumtoken;fCOIN交易所', '', 'PPmoney;爱钱进;轻易贷;永银贷;一诚一贷;一人一贷', '东霖国际', '', '阿尔泰平台', '商房所;捷麦理财;易麦理财;砚下金融;俊掌柜;91飞猫;外快理财', 'DFC;东方城', '民生证券;中天金融;民金所;泛海控股;亚太财险',
                   '华远国际;易资配;速汇国际;汇丰联合;泛金国际;信投在线;EGML环球金融;创远世纪金融;东吴金服;豪根国际;AJPFX;中港国际;股融易;信邦策略;大彩配资;飞客在线;世纪金融;世纪集团金业控股有限公司;金景配资;中瑞财讯;西安环海陆港商品交易中心;方正国际;新源财经;50ETF;弘基金融;富通国际;恒信财富;日照大宗商品交易所;海慧通;洪富期货;创昇国际期货;海拓环球融资融券;安信金控',
                    '长颈鹿;MT', '中信资本;中信资本;凤凰智信;凤凰金融;东方资产',
                   '高胜投资;投行配资;九鼎金策;香港信诚资产;中鑫建投;帝锋金业;向上金服;银丰配资;粤友钱;策马财经;盈龙策略', # 2551
'人人贷;钱来也;帛扬集团', '零花钱;趣妙租;亚远科技', '一直牛',
'钱宝科技;招钱进宝;秒到账;钱宝', '苏宁体育;苏宁文创', '布尔金融;冠联投资家园.财富领航',
'益冠创投;金凤凰;煌萨投资;国富通;易钱汇;豆包金服;观金鉴股;浩然众筹;房金网;金豆包;91飞猫;朔漪众筹;穆金所;中信创;好牛创投;汇信聚业;黎明国际;易众网;吆鸡理财', '宜聚网;和信贷;恒易融;花虾金融;恒慧融;财富中国;道口贷;信融财富;东创在线;小赢理财;草根投资;信广立诚贷;聚财猫;善林金融;宁波东创投资管理有限公司;东创投资;多乐融;金瑞龙;财富星球;米庄理财;达人贷;信而富;洋钱罐;今日捷财;财猫;短融网;银豆网;恒昌;钱盆网;笑脸金融;云端金融;开鑫金服', '海航集团;航海创新', 'HIIFX海汇国际', '欧莱雅;嗨团团购', '红域短视频;火牛视频', '快易点;八里香', '陆金所;小牛在线;蜂融网;平安集团', '钱牛牛;麦子金服;互融宝;51人品;拓道金融;玖富;付融宝', '花生日记', '乾包网;津启宝;九药网;好友邦金服;点聚财;生意贷;聚财;金羊金服;东上金服;中航生意贷;邦金服', '和信贷;贯通金服;普汇云通;点融网;众筹平台;中融民信;酷盈网;点融;盎盎理财;众力金融', '积木盒子;全球金融;网贷;宜人贷;人人贷;巴巴汇', '中国保险;加油宝;同花顺财经;合时代', 'FCoin', '海拓环球;融资融券', '链上钱包;以太森林', '新橙分期;新口子;今日推荐', '金殿环球', '金包豆;“兄弟”车贷;昌和财富;昌久财富;聚利众筹;建军财富;聚汇天下;诚天财富;金贝贝;桑善金服;金致财富;絮东投资;建元投资;豆蔓智投;中智魔方;亦川资本;红丰智投;富捷金服;厚元投资;紫檩金融;鸿百佳投资;酷盈网;中仁财富;理财咖;益冠创投;华隆资产;涌集投资;聚金袋子;聚乐资本;博美投资;融创嘉诚;大盈投资;金苏在线', '金开贷',
                   '锦安财富;金色木棉;北京卡拉卡尔科技公司;ST德奥;意隆财富', '信邦策略;中赢国际;花旗;贵金属;恒牛策略;现货白银;金源财富;海利国际;嘉晟财富;澳大利亚证券及投资委员会(ASIC);新源财经;炒外汇英国金融;聚富策略;富通国际;速汇国际;白象国际;英国金融行为管理局(FCS);瑞士金融市场监督管理局(FINMA);速赢;新西兰金融市场管理局(FMA);外汇投资;恒利金业;汇丰;环球;优信外汇;新华富时;环球金融;汇融国际', '天安金交中心;厦金中心;厦金理财平台;天安(贵州省)互联网金融资产交易中心股份有限公司;北京鑫旺阔达投资有限公司;鑫旺投资;深圳市景腾瑞贸易发展有限公司;厦门国际金融资产交易中心有限公司', '移动钱包;玖富数科集团;玖富钱包',
'分期乐;玖富普惠;投哪网;小赢理财;微贷网;桔子理财;玖富钱包;网上赚钱;小赢钱包;东方证券;玖富;贷网;龙支付;宜人贷;麻袋财富;国信证券',
'分期乐;宜人贷;玖富普惠;投哪网;小赢理财;微贷网;桔子理财;玖富钱包;网上赚钱;小赢钱包;东方证券;玖富;龙支付;微贷;国信证券;麻袋财富',
'酒业链wnn', '中资信汇投资打点有限公司', '广发基金;大成基金;博时基金;鹏华基金;汇添富基金;jojomarkets', '汇置投资;挑战者;汇置财富',
'凯顿', '皇玛金融;中北选买;中远期货;创远世纪金融;海利国际;青岛西海岸;恩圣威ncy;新源财经;华远国际;聚富策略;AJPFX;富通国际;速汇国际;丰盛金融;艾利威国际;保诚国际;豪根国际;游资通;方正国际;恒信财富;粒子金融;恒利金业;泛金国际;优信外汇;创远世纪;世纪金融;汇融国际;中首投资;安信;中融金业',
'新华都;三江购物', '', '瑞波币;恒星币;万维币;通盛币;珍宝币;富豪币;万福币;吉祥币;视界链;农业链;天使链;流量魔石;金元币;西游链;高兴币;电能链;lmc柠檬币;kdc凯帝币;csc炭汇币;scc足球币;绿链;acc防伪币;fyb弗益币;汇择投资;正谦益;睿鑫宝;德爱社区;微韵文化;益路同行;山海经;融易通;特色三妹;至尊;阿川;星火草原;恩威商城;CNY金融互助;公益社区掌心众扶;友钱宝;友义宝;影子银行;智富宝;云支付|云付通;精神传销心灵培训;亿加互助;ICA;微韵文化;kdc凯帝币;流量魔石;云支付|云付通;fyb弗益币;智富宝;星火草原;万福币;德爱社区;公益社区掌心众扶;天使链;scc足球币;影子银行;lmc柠檬币;富豪币;睿鑫宝;友义宝;吉祥币;恩威商城;CNY金融互助;acc防伪币;正谦益;金元币;友钱宝;维卡币;ICA;绿链;特色三妹;恒星币;视界链;至尊;阿川;精神传销心灵培训;益路同行;csc炭汇币;农业链;珍宝币;融易通;汇择投资;通盛币;马克币;山海经;亿加互助;西游链;万维币;高兴币;电能链',
'坤吉国际;EQR;FDEX;贵州国际交易中心;广州西勃商品交易中心;创昇国际期货;点牛融资融券;信捷策略;众昇策略;时盛财汇;鼎点策略;钱盈配资;国金策略;嘉露国际;有富策略;东方汇盈;小金橘策略;红牛策略;真牛科技;中航江南期货;象屿期货;贝赢网;信诚资产;涨悦财金;海慧通;壹恒国际;鼎点策略;钱盈配资;顺通在线;国金策略;嘉露国际;迅视资管;期权专车;桑杰股权;花旗资本;BKB;九州金服;中盟国际;中浙金控海博利;创辉国际;海南大宗;诺安期货;路易泽;安信金控;百益策略;期货大赢家', '', '易金融;麻袋财富', '理理财', '立刷', '钱保姆;分秒金融;饭饭金服;掌悦理财;黎明国际;金汇微金;一点金库;金统贷;有融网', '零点矿工', '招财宝;香港安盛投连险;AsiaOne', '平安银行;江苏银行;蚂蚁财富;天天基金;陆金所;微交易;嘉合基金', '',
'觅信DEC', '积木盒子;钱来也;钱来也网络借贷;你我贷',
'诺德基金;中天证券;大同证券;山西证券', '广州承兴营销管理有限公司;诺亚财富',
'恒利金业;FusianGallant;香港富赢通;高盛亚太;菲特尔国际;粒子金融;格兰特;中泰期权;TRENDS;欧克斯;中首投资;优越投资;嘉晟财富;速汇国际;青岛西海岸;豪根国际;保卓国际;Morse;火币网;奥瑞国际;琥珀帝国;大赢家;星亘国际;富时罗素;嘉兰待;法纳金融;高盛亚太;大赢家期货;香港优越投资中心;中元天颐;中远期货;香港富盈;宝丰国际;火币网;CFEX;帕克金融;鼎和金控;火币;中浙金控;彭义昆维权;富通国际;世纪金业;ATTEX', '领航国际资本', '前金融;陆金所;链链金融;房金所;中金贷;津融贷;车赚', '马胜金融;普顿PTFX;普顿ptfx;聚宝金融;IGOFX;3M;PTPrutonMegaBerjangka;PrutonCapital', '沃尔克外汇;MMM金融互助;马胜金融;亨英集团;HIIFX海汇国际;EA智能交易;IGOFX;HYBX', '红威投资;金融网;富利宝;豫之兴资本;钱富通;竞优理财;丰鼎金融;真信汇通;金财富;智慧理财;聚财;全局金服;榕巨互金;信息网;天诚财富;星通财富;梵丰投资;国有投资;德众金融;吉盟财富;永利宝;亿企聚财;花橙金融;理财网;吉农投资;储信理财;火理财',
 '合创', '光子链PTN', '京东;苏宁;唯品会;淘宝;考拉;权健;华林', '京东;苏宁;唯品会;淘宝;考拉;权健;华林', '天猫;淘宝;京东;优品汇;全返通',
 '融创中国;摩根士丹利;德意志银行;农银国际', 'e租宝;钱宝网;江苏联宝', 'BitMEX', 'GEC',
 '夏商风信子;时福全球购;恒优国际;宝象商务中心;金淘惠源;沃洋优品;酩悦酒业;欧颂酒业;优传保税;跑街',
 '淘宝;京东;网易', '东方证券股份有限公司;花旗环球金融(亚洲)有限公司',
 '神州泰岳投资;米庄理财;神州泰岳;图腾贷;沃时贷;米庄理财', '', 'A股头条', '',
 '速汇国际;SpeedTrad;新华富时A50;德指DAX30;恒指HSI;泰达币', '速通宝vpay',
 '悟空理财;晋商贷;铜板街;泰然金融;玖富;融牛在线;可溯金融;微贷网',
 '比特币;莱特币;无限币;夸克币;泽塔币;烧烤币;隐形金条', '优乐商城;淘优乐',
 '米缸金融;富管家;鑫聚天利;富盈;金理财;宁富盈;天安(贵州省)互联网金融资产交易中心股份有限公司;北京航天浦盛科技中心;天安金交中心',
 '富管家;鑫聚-天利;天安金交所', '寒武创投;熠美投资', '红岭创投;PPmoney;乐享宝', '金融投资;风险投资;钱生钱',
 '莱次狗;摩拜链(MobikeChain);以太坊;遨游;共生币;CNHKC;CEC;ENE;共生币;遨游;摩拜链', '挖易',
 '玖富数科;友信智投;普惠金融;金融科技;宜人金科;友信金服', '雷达币;雷达钱包;雷达支付', '未来星球', 'jojomarkets',
'西投控股;西安经开城投;西安城投(集团);西安曲江文化;长安信托;西安金控',
'远特通信;远特喜牛;远特;喜牛', '华霖财富管理股份有限公司;华霖金服',
'钱宝;百川币;e租宝;圣商;中晋;e租宝;泛亚;鲜生友请', '伽满优;富友支付;乾包网',
'伽满优;富友支付;乾包网', '伽满优;富友支付;乾包网', '策略通;牛股盈;新纪元期货;花旗证券;广汇大宗商品交易中心;贵州西部农产品交易中心;众生策略;沃伦策;中阳期货;宏琳策略;股亿讯;莱赢宝;金盛商贸;创投大师;芬吉;牛来了;掌互通;鼎盈信投;沪深689策略;国金策略;神圣策略;益达商城;壹恒国际;复兴恒福;神谷策略;江苏百瑞赢;众达国际期货;纯手期货;天兴国际期货;国人期投;超人国际;融盛在线;众赢投顾;神圣策略;股易融;花旗证券;鸿运信投;众生策略;财创期选;劲牛期权;华信策略;中讯策略;创期国投;顺配宝;香港英联策略', 'UnWallet', '小九花花', '众赢;普惠金融;悟空理财;叮当贷;玖富钱包',
 '小象金融;响当当;百仁贷;公众理财;宜泉资本;信而富;99财富;一起理财;酷盈网;人人爱家', '小诸葛金服;芝麻宝金服;天农金融;有融网;狐小狸理财;城城理财;海星宝理财;易纳理财;普益金服;财富中国;小灰熊金服;红八财富;贷你盈;超人贷;台州贷;百金贷;浣熊理财;银号理财;伟民金服;啄米理财;麻宝金服;天农金融;有融网;狐小狸理财;城城理财;海星宝理财;易纳理财;普益金服;财富中国;小灰熊金服;红八财富;贷你盈;超人贷;台州贷;百金贷;浣熊理财;银号理财;伟民金服;啄米理财',
'', '', '中云国际;E路商城', '京东金融;京东理财;小白理财;年年盈;月月盈;金理财;天天盈;币基金', '猎金集团;猎金全民影视',
 '巨人理财;掌悦理财;一点金库;央金所;利利金服;领奇理财;投米乐;领奇理财;一点金库;利利金服;微米在线;掌悦理财;巨人理财;投米乐;微米在线', '宜信惠民投资管理', '以太云', '', '宁波甬坚网络科技有限公司;麦穗金服;钱内助;三金在线;民信金服;利民网;巨如众吧;抢钱通;金投手;壹万木投资;宁海县永坚混凝土有限公司;易麦理财',
 '金证科技;新大陆;兴业数金;工银科技;高伟达;国泰君安研究所;民生科技','仁远资本;贝米钱包;东泽汇顺发;人人贷', '啄米理财;快点理财;甬e贷;多米金融;壹佰金融;津启宝;利魔方;温州贷;招金猫;易贷在线', '富通环球投资;恒信环球投资;恒信国际;恒信贵金属;恒信集团', '五星基金;华安策略', '你我贷;玖富普惠;宜人贷;微贷网', '有信钱包;芝麻分贷款', '花生日记;菜鸟;云集微店;',
 '皮城金融;企查查;海宁民间融资服务中心;海宁皮城', '浙江谢志宇控股集团有限公司;杭州凯蓝汽车租赁有限公司', '蘑菇街;飞猪旅行;侠侣联盟;厦门侠网旅游服务有限公司;厦门侠网旅游服务有限公司', '道琼斯指数;平安证券;HDI', '海贝国际;IGOFX平台', '光大保德信鼎鑫基金;华泰期货',
'MoreToken钱包;Coinone;Tokenstore钱包;BossToken钱包;智能搬砖;BossToken;SecurityToken',
'玖富叮当贷;马上金融;招联好期贷;小鲨易贷', '天津银行;智圣金服;金融理财;全民理财;广州智圣大健康投资有限公司', '智圣金服;金融理财',
'红橙优选;微豪配资;恩圣威;MORSE;易信;新源财经;中北选买;嘉晟财富;恒利金业;ATTEX;优信外汇;速汇国际;威海中元天颐;AJPFX;中元天颐;中泰之星',
'信邦;中赢国际;白象国际;花旗;中瑞财讯;恒利金业;海利国际;嘉晟财富;新源财经优信外汇;环球金融;汇融国际;聚富策略;富通国际;速汇国际', '中信华睿;华安策略;福盛期权;杜德配资;WIRECADD;MALAFY;金田策略;;Helong和隆;银岛配资;世纪金业;鼎盈信投;信融期权;弘基金融;天臣配资;久联优配;致富配资;鼎泽配资;涵星配资;鑫配资;鼎盈信投;信邦策略;百益策略;安信金控;CFX圆汇;格林期货;鸿运信投;信邦策略;宏观策略;金多多配资;罗宾智投;信溢国际;弘基金融;万荣国际;多乾国际;合欣国际;EGML;环球金融;HATSKY;速达国际;中阳期货;丰讯凯国际FDEX',
'华远国际;撮合网;粒子金融;明道配资;长江期货;佳银融资融券;海南大宗商品交易中心;贵州国际商品交易中心;策略资本;稳赢策略;盈策略;川商联宗商品;外汇投资;天元策略;聚富策略;环海陆港;汇融国际;领航配资;新纪元;广州西勃商品交易中心;权金汇;东方财经;中远期货;诚信配资;方正国际;新源财经;艾利威;大连商品交易所;赛岳恒配资;弘基金融;创期国投;盛赢期服', '“慧盈”理财;“家和盈”理财;“增盈”理财', '普信金服APP', '投哪网;麻袋财富;东方证券;桔子理财;微贷网;国信证券;小赢理财;分期乐;宜人贷;小赢钱包', '中金珠宝', 'P2B;微金融;芒果金融', '', '',
'股王配资;DBC币;众融众投;新富金融;恒通国际;微交易;大東方国际商品交易集团;鑫汇环球;大東方国际商品交易集团;恒通国际微交易;DBC币;新富金融;股王配资;众融众投;鑫汇环球',
'中航期货;震泰国际;ainol艾诺;joso聚硕;tfyforex;国峰贵金属',
'plustoken', '亚马逊', 'brt房地产信托', '火币;okex', '嘉盛', '沃客', 'okex', '爱福瑞',
'云讯通;云数贸;五行币;善心汇;LCF项目;云联惠;星火草原;云指商城;世界华人联合会;世界云联;WV梦幻之旅;维卡币;万福币;二元期权;云梦生活;恒星币;摩根币;网络黄金;1040阳光工程;中绿资本;赛比安;K币商城;五化联盟;国通通讯网络电话;EGD网络黄金;万达复利理财;MFC币理财;微转动力;神州互联商城;绿藤理财;绿色世界理财;宝微商城;中晋系;马克币;富迪;万通奇迹;港润信贷;CNC九星;世界云联;沃客生活;天音网络;莱汇币;盛大华天;惠卡世纪;开心理财网;贝格邦BGB;FIS数字金库;SF共享金融;DGC共享币;易赚宝;丰果游天下;天狮集团;薪金融;MGN积分宝;光彩币;亿加互助;GemCoin(珍宝币);老妈乐'


                  ]  # 对应id的修正实体
id_list = train_df['id'].tolist()
label_list = train_df['unknownEntities'].tolist()
for i,idx in enumerate(id_list):
    if idx in false_get_id:
        label_list[i] = repair_id_label[false_get_id.index(idx)]
# 修复过程中漏了几个标签,在这里补上
label_list[2409] = '金融科技(Fintech)'
label_list[2479] = '玖富钱包;玖富数科集团;玖富钱包APP'
label_list[3596] = '盈盈理财;乾包网;臻理财;蜗牛在线'
train_df['unknownEntities'] = label_list
train_df = train_df[~train_df['unknownEntities'].isnull()]  # 删除空标签
train_df.to_csv('new_train_df.csv')
  • 数据集的划分
print('Train Set Size:',train_df.shape)
new_dev_df = train_df[4000:]
frames = [train_df[:2000],train_df[2001:4000]]
new_train_df = pd.concat(frames) #训练集
new_train_df = new_train_df.fillna('')
new_test_df = test_df[:] #测试集
new_test_df.to_csv('new_test_df.csv', encoding='utf-8', index=False)
print('Test Set Size:',test_df.shape)
  • 对长文本按照标点优先级进行切割
def _cut(sentence):
    new_sentence = []
    sen = []
    for i in sentence:
        if i in ['。', '!', '?', '?'] and len(sen) != 0:
            sen.append(i)
            new_sentence.append("".join(sen))
            sen = []
            continue
        sen.append(i)

    if len(new_sentence) <= 1: # 一句话超过max_seq_length且没有句号的,用","分割,再长的不考虑了。
        new_sentence = []
        sen = []
        for i in sentence:
            if i.split(' ')[0] in [',', ','] and len(sen) != 0:
                sen.append(i)
                new_sentence.append("".join(sen))
                sen = []
                continue
            sen.append(i)
    if len(sen) > 0:  # 若最后一句话无结尾标点,则加入这句话
        new_sentence.append("".join(sen))
    return new_sentence
def cut_test_set(text_list):
    cut_text_list = []
    cut_index_list = []
    for text in text_list:

        temp_cut_text_list = []
        text_agg = ''
        if len(text) < len_treshold:
            temp_cut_text_list.append(text)
        else:
            sentence_list = _cut(text)  # 一条数据被切分成多句话
            for sentence in sentence_list:
                if len(text_agg) + len(sentence) < len_treshold:
                    text_agg += sentence
                else:
                    temp_cut_text_list.append(text_agg)
                    text_agg = sentence
            temp_cut_text_list.append(text_agg)  # 加上最后一个句子

        cut_index_list.append(len(temp_cut_text_list))
        cut_text_list += temp_cut_text_list

    return cut_text_list, cut_index_list
def cut_train_and_dev_set(text_list, label_list):
    cut_text_list = []
    cut_index_list = []
    cut_label_list = []
    for i, text in enumerate(text_list):
        if label_list[i] != '':
            text_label_list = label_list[i].split(';')  # 获取该条数据的实体列表
            temp_cut_text_list = []
            temp_cut_label_list = []
            text_agg = ''
            if len(text) < len_treshold:
                temp_cut_text_list.append(text)
                temp_cut_label_list.append(label_list[i])
            else:

                sentence_list = _cut(text)  # 一条数据被切分成多句话

                for sentence in sentence_list:
                    if len(text_agg) + len(sentence) < len_treshold:
                        text_agg += sentence
                    else:
                        new_label = []  # 新构成的句子的标签列表
                        for label in text_label_list:
                            if label in text_agg and label != '':
                                new_label.append(label)

                        if len(new_label) > 0:
                            temp_cut_text_list.append(text_agg)
                            temp_cut_label_list.append(";".join(new_label))

                        text_agg = sentence
                # 加回最后一个句子
                new_label = []
                for label in text_label_list:
                    if label in text_agg and label != '':
                        new_label.append(label)
                if len(new_label) > 0:
                    temp_cut_text_list.append(text_agg)
                    temp_cut_label_list.append(";".join(new_label))

            cut_index_list.append(len(temp_cut_text_list))
            cut_text_list += temp_cut_text_list
            cut_label_list += temp_cut_label_list

    return cut_text_list, cut_index_list, cut_label_list
train_text_list = new_train_df['text'].tolist()
train_label_list = new_train_df['unknownEntities'].tolist()
train_id_list = new_train_df['id'].tolist()

dev_text_list = new_dev_df['text'].tolist()
dev_label_list = new_dev_df['unknownEntities'].tolist()

test_text_list = new_test_df['text'].tolist()
test_id_list = new_test_df['id'].tolist()
test_cut_text_list, cut_index_list = cut_test_set(test_text_list)
train_cut_text_list, train_cut_index_list ,train_cut_label_list = cut_train_and_dev_set(train_text_list,  train_label_list)
dev_cut_text_list, dev_cut_index_list, dev_cut_label_list = cut_train_and_dev_set(dev_text_list, dev_label_list)
  • 测试数据集切分是否正确
flag = True
for i,text in enumerate(train_cut_text_list):
    label_list = train_cut_label_list[i].split(";")
    for li in label_list:
        #标签不在text中,或者标签为空
        if li not in text:
            print(i)
            print(li)
            print(text)
            flag = False
            print()
            break
        if li == '':
            print(li)
            print(text)
            flag = False
            print()
if flag:
    print("训练集切分正确!")
else:
    print("训练集切分错误")
flag = True
for i, text in enumerate(dev_cut_text_list):
    label_list = dev_cut_label_list[i].split(';')
    for li in label_list:
        if li not in text:
            print(i)
            print(li)
            print(text)
            print()
            flag = False

if flag:
    print("验证集切分正确!")
else:
    print("验证集切分错误!")
  • 保存切分索引
cut_index_dict = {'cut_index_list': cut_index_list}
with open('cut_index_list.json', 'w') as f:
    json.dump(cut_index_dict, f, ensure_ascii=False) #将一个Python数据结构转换为JSON
dev_cut_index_dict = {'cut_index_list': dev_cut_index_list}
with open('dev_cut_index_list.json', 'w') as f:
    json.dump(dev_cut_index_dict, f, ensure_ascii=False) #将一个Python数据结构转换为JSON
train_dict = {'text': train_cut_text_list, 'unknownEntities': train_cut_label_list}
train_df = pd.DataFrame(train_dict)

dev_dict = {'text': dev_cut_text_list, 'unknownEntities': dev_cut_label_list}
dev_df = pd.DataFrame(dev_dict)

test_dict = {'text': test_cut_text_list}
test_df = pd.DataFrame(test_dict)

print('训练集:', train_df.shape)
print('验证集:', dev_df.shape)
print('测试集:', test_df.shape)
  • 将切分的数据转成BIO数据格式
with codecs.open('train.txt', 'w', encoding='utf-8') as up:
    for row in train_df.iloc[:].itertuples():#将DataFrame迭代为元祖。
        #print(row.text)
        text_lbl = row.text
        entitys = str(row.unknownEntities).split(';')
        for entity in entitys:
            text_lbl = text_lbl.replace(entity, 'Ё' + (len(entity) - 1) * 'Ж')
        for c1, c2 in zip(row.text, text_lbl):
            if c2 == 'Ё':
                up.write('{0} {1}\n'.format(c1, 'B-ORG'))
            elif c2 == 'Ж':
                up.write('{0} {1}\n'.format(c1, 'I-ORG'))
            else:
                up.write('{0} {1}\n'.format(c1, 'O'))
        up.write('\n')
 with codecs.open('dev.txt', 'w', encoding='utf-8') as up:
    for row in dev_df.iloc[:].itertuples():
        # print(row.unknownEntities)
        text_lbl = row.text
        entitys = str(row.unknownEntities).split(';')
        for entity in entitys:
            text_lbl = text_lbl.replace(entity, 'Ё' + (len(entity) - 1) * 'Ж')

        for c1, c2 in zip(row.text, text_lbl):
            if c2 == 'Ё':
                up.write('{0} {1}\n'.format(c1, 'B-ORG'))
            elif c2 == 'Ж':
                up.write('{0} {1}\n'.format(c1, 'I-ORG'))
            else:
                up.write('{0} {1}\n'.format(c1, 'O'))

        up.write('\n')
with codecs.open('test.txt', 'w', encoding='utf-8') as up:
    for row in test_df.iloc[:].itertuples():

        text_lbl = row.text
        for c1 in text_lbl:
            up.write('{0} {1}\n'.format(c1, 'O'))

        up.write('\n')

在机器学习中数据的质量直接决定了任务的天花板,而模型和算法只是无限的去接近这个天花板,本文对前期的数据处理进行了详细的介绍。最终在pycharm中运行结果如下图所示:
在这里插入图片描述
在本地data/process_data中生成了我们需要的文件:
在这里插入图片描述

4.2 数据迭代器

在将数据处理成标准的BIO标注格式后,需要使用数据迭代器将数据喂给模型中,虽然所一般情况下可以将很多数据直接读入内存中,但是对于一些较大的神经网络或晕训练模型时,数据量往往较大,我们需要分批次将数据送入后去的神经网络中。

  • 加载数据
def load_data(data_file):
   """
   读取BIO的数据
   :param file:
   :return:
   """
  • 读取语料和标记
def create_example(lines):
"""
读取语料和标签
: return examples
"""
  • 得到实例及标签
def get_examples(data_file):
   return create_example(
       load_data(data_file)
   )
def get_labels():
       return ["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "X","[CLS]","[SEP]", '']
  • 数据迭代器
class DataIterator:
   """
   数据迭代器
   """
   def __init__(self, batch_size, data_file, tokenizer, use_bert=False, seq_length=100, is_test=False,):
       self.data_file = data_file
       self.data = get_examples(data_file)   #获取 编码好的数据
       self.batch_size = batch_size
       self.use_bert = use_bert
       self.seq_length = seq_length
       self.num_records = len(self.data)
       self.all_tags = []
       self.idx = 0  # 数据索引
       self.all_idx = list(range(self.num_records))  # 全体数据索引
       self.is_test = is_test

       if not self.is_test:
           self.shuffle()
       self.tokenizer = tokenizer
       self.label_map = {}
       for (i, label) in enumerate(get_labels(), 1):
           self.label_map[label] = i
       self.unknow_tokens = self.get_unk_token()

       print(self.unknow_tokens)
       print(self.num_records)

   def get_unk_token(self):
       unknow_token = set()
       for example_idx in self.all_idx:
           textlist = self.data[example_idx].text.split(' ')

           for i, word in enumerate(textlist):
               token = self.tokenizer.tokenize(word)

               if '[UNK]' in token:
                   unknow_token.add(word)
       return unknow_token

   def convert_single_example(self, example_idx):  #构造一个batchsize的数据
       textlist = self.data[example_idx].text.split(' ')
       labellist = self.data[example_idx].label.split(' ')
       tokens = textlist  # 区分大小写
       labels = labellist

       if len(tokens) >= self.seq_length - 1:
           tokens = tokens[0:(self.seq_length - 2)]
           labels = labels[0:(self.seq_length - 2)]
       ntokens = []
       segment_ids = []
       label_ids = []
       ntokens.append("[CLS]")
       segment_ids.append(0)
       label_ids.append(self.label_map["[CLS]"])

       upper_letter = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R',
                       'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
                       ]
       for i, token in enumerate(tokens):
           if token in self.unknow_tokens and token not in upper_letter:
               token = '[UNK]'
               ntokens.append(token)  # 全部转换成小写, 方便BERT词典
           else:
               ntokens.append(token.lower())  # 全部转换成小写, 方便BERT词典
           segment_ids.append(0)
           label_ids.append(self.label_map[labels[i]])

       tokens = ["[CLS]"] + tokens + ["[SEP]"]   #编写BERT的词向量编码
       ntokens.append("[SEP]")

       segment_ids.append(0)
       label_ids.append(self.label_map["[SEP]"]) #加入BERT的段落编码

       input_ids = self.tokenizer.convert_tokens_to_ids(ntokens)  #判断输入的序列是否等于BERT的最大输入
       input_mask = [1] * len(input_ids)
       while len(input_ids) < self.seq_length :
           input_ids.append(0)
           input_mask.append(0)
           segment_ids.append(0)
           label_ids.append(0)
           ntokens.append("**NULL**")
           tokens.append("**NULL**")

       assert len(input_ids) == self.seq_length
       assert len(input_mask) == self.seq_length
       assert len(segment_ids) == self.seq_length
       assert len(label_ids) == self.seq_length
       assert len(tokens) == self.seq_length
       return input_ids, input_mask, segment_ids, label_ids, tokens

   def shuffle(self):
       np.random.shuffle(self.all_idx) #随机打乱

   def __iter__(self):
       return self

   def __next__(self):
       if self.idx >= self.num_records:  # 迭代停止条件
           self.idx = 0
           if self.is_test == False:
               self.shuffle()
           raise StopIteration

       input_ids_list = []
       input_mask_list = []
       segment_ids_list = []
       label_ids_list = []
       tokens_list = []

       num_tags = 0
       while num_tags < self.batch_size:  # 每次返回batch_size个数据
           idx = self.all_idx[self.idx]
           res = self.convert_single_example(idx)
           if res is None:
               self.idx += 1
               if self.idx >= self.num_records:
                   break
               continue
           input_ids, input_mask, segment_ids, label_ids, tokens = res

           # 一个batch的输入
           input_ids_list.append(input_ids)
           input_mask_list.append(input_mask)
           segment_ids_list.append(segment_ids)
           label_ids_list.append(label_ids)
           tokens_list.append(tokens)

           if self.use_bert:
               num_tags += 1

           self.idx += 1
           if self.idx >= self.num_records:
               break

       return input_ids_list, input_mask_list, segment_ids_list, label_ids_list, self.seq_length, tokens_list
  • BERT源码分词器tokenizer.vovab:
    在这里插入图片描述
  • 迭代器结果:

    数据迭代次梅批次的结果:
打印详细信息:
(1)input_ids_list
[[101, 1546, 671, 679, 1962, 4638, 2218, 3221, 3680, 1921, 6963, 6206, 5709, 1126, 1146, 7164, 1343, 678, 1296, 2798, 833, 3119, 4660, 511, 757, 5468, 5381, 6848, 2885, 1962, 7555, 4680, 2523, 7410, 8024, 5543, 1075, 782, 5549, 8024, 2868, 2245, 1730, 7339, 8024, 4937, 6617, 7555, 4680, 4696, 4638, 679, 1914, 8024, 6656, 4708, 7770, 782, 6624, 1403, 2330, 2292, 8024, 1914, 702, 3301, 1351, 1730, 7339, 1914, 3340, 2797, 5619, 8024, 671, 6629, 2456, 6863, 6498, 6756, 7339, 8024, 6375, 3581, 2094, 1730, 7339, 680, 872, 1762, 165, 148, 145, 3025, 2797, 2400, 6822, 106, 2190, 2970, 1730, 7339, 7566, 2193, 782, 117, 2544, 928, 131, 161, 143, 155, 126, 123, 127, 127, 129, 129, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 523, 3297, 3173, 1355, 2357, 524, 1912, 3726, 122, 122, 121, 5381, 123, 121, 122, 130, 2399, 124, 3299, 769, 3211, 1555, 2145, 6401, 5273, 7946, 3528, 3822, 2634, 671, 2157, 769, 3211, 1555, 3221, 1415, 7028, 6228, 2145, 2787, 4638, 1898, 7509, 8024, 2190, 2521, 2145, 2787, 4638, 2578, 2428, 1963, 862, 8024, 769, 3211, 1555, 4638, 3302, 1218, 2578, 2428, 8024, 738, 3221, 2832, 6598, 5442, 5440, 2175, 2398, 1378, 7478, 2382, 1068, 3800, 4638, 671, 3175, 7481, 8024, 2226, 1762, 123, 121, 122, 130, 2399, 124, 3299, 769, 3211, 1555, 2145, 6401, 5273, 7946, 3528, 511, 680, 123, 3299, 4685, 3683, 8024, 124, 3299, 769, 3211, 1555, 2145, 6401, 5273, 3528, 1469, 7946, 3528, 4638, 2832, 6401, 3144, 7030, 3146, 860, 1772, 3300, 2792, 1872, 1217, 511, 5273, 136, 136, 3528, 5273, 3528, 704, 8024, 2832, 6401, 1066, 123, 130, 6629, 8024, 680, 677, 702, 3299, 4685, 3683, 1872, 1217, 127, 6629, 8024, 4916, 3353, 1905, 4415, 123, 129, 6629, 8024, 1905, 4415, 4372, 711, 130, 127, 110, 511, 143, 166, 151, 162, 160, 143, 146, 147, 160, 510, 148, 166, 158, 160, 157, 510, 4078, 3828, 7770, 3726, 510, 154, 145, 149, 510, 4078, 3828, 4636, 3726, 5023, 769, 3211, 1555, 1772, 3221, 122, 6629, 2832, 6401, 8039, 143, 145, 167, 4921, 674, 6395, 1171, 510, 143, 154, 158, 143, 160, 151, 5687, 4886, 4448, 510, 147, 166, 156, 147, 161, 161, 510, 162, 151, 145, 153, 155, 151, 154, 154, 510, 166, 155, 510, 3211, 928, 147, 143, 161, 167, 155, 143, 160, 153, 147, 162, 161, 1772, 3221, 123, 6629, 2832, 6401, 8039, 148, 166, 162, 155, 2168, 2868, 124, 6629, 2832, 6401, 8039, 147, 151, 149, 150, 162, 145, 143, 158, 4078, 3828, 3211, 3726, 125, 6629, 2832, 6401, 8024, 684, 6963, 2533, 1168, 4916, 3353, 1905, 4415, 511, 6821, 763, 769, 3211, 1555, 4638, 2832, 6401, 4916, 3353, 1905, 4415, 3126, 4372, 6809, 1168, 122, 121, 121, 110, 511, 151, 145, 155, 143, 160, 153, 147, 162, 161, 2832, 6401, 125, 6629, 8024, 4916, 3353, 1905, 4415, 124, 6629, 8024, 4916, 3353, 1905, 4415, 3126, 4372, 6809, 1168, 128, 126, 110, 511, 7946, 136, 136, 136, 3528, 123, 121, 122, 130, 2399, 124, 3299, 7946, 3528, 2832, 6401, 1066, 122, 130, 6629, 8024, 680, 677, 702, 3299, 4685, 3683, 1872, 1217, 749, 128, 6629, 8024, 2867, 5318, 2772, 3867, 3353, 1905, 4415, 122, 128, 6629, 511, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
(2)input_mask_list
[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
(3)segment_ids_list
[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
(4)label_ids_list
[[9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [9, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]
(5) self.seq_length
512
(6)tokens_list
[['[CLS]', '唯', '一', '不', '好', '的', '就', '是', '每', '天', '都', '要', '花', '几', '分', '钟', '去', '下', '单', '才', '会', '收', '益', '。', '互', '联', '网', '选', '择', '好', '项', '目', '很', '难', ',', '能', '养', '人', '脉', ',', '拓', '展', '团', '队', ',', '稳', '赢', '项', '目', '真', '的', '不', '多', ',', '跟', '着', '高', '人', '走', '向', '巅', '峰', ',', '多', '个', '朋', '友', '团', '队', '多', '条', '手', '臂', ',', '一', '起', '建', '造', '豪', '车', '队', ',', '让', '橙', '子', '团', '队', '与', '你', '在', 'W', 'F', 'C', '携', '手', '并', '进', '!', '对', '接', '团', '队', '领', '导', '人', ',', '微', '信', ':', 's', 'a', 'm', '5', '2', '6', '6', '8', '8', '[SEP]', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**'], ['[CLS]', '【', '最', '新', '发', '布', '】', '外', '汇', '1', '1', '0', '网', '2', '0', '1', '9', '年', '3', '月', '交', '易', '商', '客', '诉', '红', '黑', '榜', '洞', '悉', '一', '家', '交', '易', '商', '是', '否', '重', '视', '客', '户', '的', '声', '音', ',', '对', '待', '客', '户', '的', '态', '度', '如', '何', ',', '交', '易', '商', '的', '服', '务', '态', '度', ',', '也', '是', '投', '资', '者', '考', '察', '平', '台', '非', '常', '关', '注', '的', '一', '方', '面', ',', '尽', '在', '2', '0', '1', '9', '年', '3', '月', '交', '易', '商', '客', '诉', '红', '黑', '榜', '。', '与', '2', '月', '相', '比', ',', '3', '月', '交', '易', '商', '客', '诉', '红', '榜', '和', '黑', '榜', '的', '投', '诉', '数', '量', '整', '体', '均', '有', '所', '增', '加', '。', '红', '?', '?', '榜', '红', '榜', '中', ',', '投', '诉', '共', '2', '9', '起', ',', '与', '上', '个', '月', '相', '比', '增', '加', '6', '起', ',', '积', '极', '处', '理', '2', '8', '起', ',', '处', '理', '率', '为', '9', '6', '%', '。', 'A', 'x', 'i', 'T', 'r', 'a', 'd', 'e', 'r', '、', 'F', 'X', 'P', 'R', 'O', '、', '澳', '洲', '高', '汇', '、', 'L', 'C', 'G', '、', '澳', '洲', '百', '汇', '等', '交', '易', '商', '均', '是', '1', '起', '投', '诉', ';', 'A', 'C', 'Y', '稀', '万', '证', '券', '、', 'A', 'l', 'p', 'a', 'r', 'i', '艾', '福', '瑞', '、', 'E', 'x', 'n', 'e', 's', 's', '、', 'T', 'i', 'c', 'k', 'm', 'i', 'l', 'l', '、', 'X', 'M', '、', '易', '信', 'e', 'a', 's', 'y', 'M', 'a', 'r', 'k', 'e', 't', 's', '均', '是', '2', '起', '投', '诉', ';', 'F', 'X', 'T', 'M', '富', '拓', '3', '起', '投', '诉', ';', 'E', 'i', 'g', 'h', 't', 'C', 'a', 'p', '澳', '洲', '易', '汇', '4', '起', '投', '诉', ',', '且', '都', '得', '到', '积', '极', '处', '理', '。', '这', '些', '交', '易', '商', '的', '投', '诉', '积', '极', '处', '理', '效', '率', '达', '到', '1', '0', '0', '%', '。', 'I', 'C', 'M', 'a', 'r', 'k', 'e', 't', 's', '投', '诉', '4', '起', ',', '积', '极', '处', '理', '3', '起', ',', '积', '极', '处', '理', '效', '率', '达', '到', '7', '5', '%', '。', '黑', '?', '?', '?', '榜', '2', '0', '1', '9', '年', '3', '月', '黑', '榜', '投', '诉', '共', '1', '9', '起', ',', '与', '上', '个', '月', '相', '比', '增', '加', '了', '7', '起', ',', '拒', '绝', '或', '消', '极', '处', '理', '1', '7', '起', '。', '[SEP]', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**', '**NULL**']]

4.2模型构建

NER一直是NLP领域中的研究热点,从早期基于词典和规则的方法,到传统机器学习的方法,到近年来基于深度学习的方法,NER研究进展的大概趋势大致如下图所示:
在这里插入图片描述
当时的源代码作者尝试了BERT+BILSTM+CRF、BERT+IDCNN+CRF、动态权重融合BERT+IDCNN/BILSTM+CRF。

(1)BERT+BILSTM+CRF
在预训练模型BERT未出现之前,BILSTM-CRF就是一种比较流行的命名实体识别框架,这里将BERT的token向量(即含有上下文信息的词向量)喂到BILSTM中,BILSTM可以进一步提取特征,CRF可以更好的学习到标签之间的约束关系。网络模型一般输出的是一个句子的每个字的分类,不能够较好的考虑该字左右两边的结果,比如小明,小(B-PER)明(I-PER),可能会预测成小(B-PER)明(B-PER),而CRF可以避免这种情况,比如B-PER后面只能是I-PER或者O类标签。感兴趣的话可以参考这篇文章 博客.
BERT+BILSTM+CRF的结构如下图所示:
在这里插入图片描述
(2)BERT+IDCNN+CRF
IDCNN(Iterated Dilated Convolutional Neural Networks)迭代空洞卷积神经网络,和传统的CNN相比,IDCNN可以较好的捕捉更长的上下文信息,和时序模型LSTM相比,IDCNN可以实行并行化运算,大大提高了模型的运算速度,这里不再对 IDCNN模型进行赘述,感兴趣的话可以参考以下几篇文章:
【1】IDCNN for NER论文阅读笔记: 链接.
【2】如何理解空洞卷积(dilated convolution): 链接1.
【3】如何理解空洞卷积(dilated convolution): 链接2.
BERT+IDCNN+CRF的结果如下图所示:
在这里插入图片描述
本文采用的是当时IDCNN作者默认的模型参数,各层的dilation为[1,1,2],卷积核参数为3x3
(3)动态融合的BERT+IDCNN+CRF
BERT的每一层编码对文本都有着不同的理解,本文将BERT的12层编码赋予了不同的权重,初始化公式按照如下公式进行初始化:
a i = D e n s e u n i t = 1 ( r e p r e s e n t i ) a_i=Dense_{unit=1}(represent_i) ai=Denseunit=1(representi)
然后通过训练来确定权重,并将每一层的表示进行加权平均,在后接一层全连接层降维至512维(BERT可以接收的最大序列长度),公式如下:
o u t p u t = D e n s e u n i t = 512 ( ∑ i = 1 n a i ⋅ r e p r e s e n t i ) output = Dense_{unit=512}(\sum_{i=1}^{n}a_i\cdot represent_i) output=Denseunit=512(i=1nairepresenti)
式中represent i 为BERT每一层输出的表示,ai 为权重BERT每一层表示的权重值。
动态融合的BERT结构图如下图所示:
在这里插入图片描述
后续将512维的向量输入BILSTM/IDCNN-CRF模型即可。
(4)代码实现

-动态融合BERT和原生BERT代码:
【注意】这里的BERT代码直接调用bert_modeling库(该库来源于官方库)

    def bert_embed(self, bert_init=True):
        """
        读取BERT的TF模型
        :param bert_init:
        :return:
        """
        bert_config_file = self.config.bert_config_file
        bert_config = BertConfig.from_json_file(bert_config_file)
        # batch_size, max_seq_length = get_shape_list(self.input_x_word)
        # bert_mask = tf.pad(self.input_mask, [[0, 0], [2, 0]], constant_values=1)  # tensor左边填充2列
        model = BertModel(
            config=bert_config,
            is_training=self.is_training,  # 微调
            input_ids=self.input_x_word,
            input_mask=self.input_mask,
            token_type_ids=None,
            use_one_hot_embeddings=False)

        layer_logits = []
        for i, layer in enumerate(model.all_encoder_layers):
            layer_logits.append(
                tf.layers.dense(
                    layer, 1,
                    kernel_initializer=tf.truncated_normal_initializer(stddev=0.02),
                    name="layer_logit%d" % i
                )
            )

        layer_logits = tf.concat(layer_logits, axis=2)  # 第三维度拼接
        layer_dist = tf.nn.softmax(layer_logits)
        seq_out = tf.concat([tf.expand_dims(x, axis=2) for x in model.all_encoder_layers], axis=2)
        pooled_output = tf.matmul(tf.expand_dims(layer_dist, axis=2), seq_out)
        pooled_output = tf.squeeze(pooled_output, axis=2)
        pooled_layer = pooled_output
        # char_bert_outputs = pooled_laRERyer[:, 1: max_seq_length - 1, :]  # [batch_size, seq_length, embedding_size]
        char_bert_outputs = pooled_layer

        if self.config.use_origin_bert:
            final_hidden_states = model.get_sequence_output()  # 原生bert
            self.config.embed_dense_dim = 768
        else:
            final_hidden_states = char_bert_outputs  # 多层融合bert
            self.config.embed_dense_dim = 512

        tvars = tf.trainable_variables()
        init_checkpoint = self.config.bert_file  # './chinese_L-12_H-768_A-12/bert_model.ckpt'
        assignment_map, initialized_variable_names = get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        if bert_init:
            tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

        tf.logging.info("**** Trainable Variables ****")
        for var in tvars:
            init_string = ""
            if var.name in initialized_variable_names:
                init_string = ", *INIT_FROM_CKPT*"
            print("  name = {}, shape = {}{}".format(var.name, var.shape, init_string))
        print('init bert from checkpoint: {}'.format(init_checkpoint))
        return final_hidden_states
  • IDCNN代码:
    def IDCNN_layer(self, model_inputs, name=None):
        """
        :param idcnn_inputs: [batch_size, num_steps, emb_size]
        :return: [batch_size, num_steps, cnn_output_width]
        """
        model_inputs = tf.expand_dims(model_inputs, 1)
        with tf.variable_scope("idcnn" if not name else name):
            shape = [1, self.filter_width, self.embedding_dim,
                     self.num_filter]
            print(shape)
            filter_weights = tf.get_variable(
                "idcnn_filter",
                shape=[1, self.filter_width, self.embedding_dim, self.num_filter],
                initializer=self.initializer
            )

            layerInput = tf.nn.conv2d(model_inputs,
                                      filter_weights,
                                      strides=[1, 1, 1, 1],
                                      padding="SAME",
                                      name="init_layer")
            finalOutFromLayers = []
            totalWidthForLastDim = 0
            for j in range(self.repeat_times):
                for i in range(len(self.layers)):
                    dilation = self.layers[i]['dilation']
                    isLast = True if i == (len(self.layers) - 1) else False
                    with tf.variable_scope("atrous-conv-layer-%d" % i,
                                           reuse=tf.AUTO_REUSE):
                        w = tf.get_variable(
                            "filterW",
                            shape=[1, self.filter_width, self.num_filter,
                                   self.num_filter],
                            initializer=tf.contrib.layers.xavier_initializer())
                        b = tf.get_variable("filterB", shape=[self.num_filter])
                        conv = tf.nn.atrous_conv2d(layerInput,
                                                   w,
                                                   rate=dilation,
                                                   padding="SAME")
                        conv = tf.nn.bias_add(conv, b)
                        conv = tf.nn.relu(conv)
                        if isLast:
                            finalOutFromLayers.append(conv)
                            totalWidthForLastDim += self.num_filter
                        layerInput = conv
            finalOut = tf.concat(axis=3, values=finalOutFromLayers)
            keepProb = tf.cond(self.is_training, lambda: 0.8, lambda: 1.0)
            # keepProb = 1.0 if reuse else 0.5
            finalOut = tf.nn.dropout(finalOut, keepProb)

            finalOut = tf.squeeze(finalOut, [1])
            finalOut = tf.reshape(finalOut, [-1, totalWidthForLastDim])
            self.cnn_output_width = totalWidthForLastDim
            return finalOut

    def project_layer_idcnn(self, idcnn_outputs, name=None):
        """
        :param lstm_outputs: [batch_size, num_steps, emb_size]
        :return: [batch_size, num_steps, num_tags]
        """
        with tf.name_scope("project" if not name else name):
            # project to score of tags
            with tf.name_scope("logits"):
                W = tf.get_variable("PLW", shape=[self.cnn_output_width, self.relation_num],
                                    dtype=tf.float32, initializer=self.initializer)

                b = tf.get_variable("PLb", initializer=tf.constant(0.001, shape=[self.relation_num]))

                pred = tf.nn.xw_plus_b(idcnn_outputs, W, b)

            return tf.reshape(pred, [-1, self.num_steps, self.relation_num], name='pred_logits')
  • BILSTM代码
    def biLSTM_layer(self, lstm_inputs, lstm_dim, lengths, name=None):
        """
        :param lstm_inputs: [batch_size, num_steps, emb_size]
        :return: [batch_size, num_steps, 2*lstm_dim]
        """
        with tf.name_scope("char_BiLSTM" if not name else name):
            lstm_cell = {}
            for direction in ["forward", "backward"]:
                with tf.name_scope(direction):
                    lstm_cell[direction] = rnn.CoupledInputForgetGateLSTMCell(
                        lstm_dim,
                        use_peepholes=True,
                        initializer=self.initializer,
                        state_is_tuple=True)
            outputs, final_states = tf.nn.bidirectional_dynamic_rnn(
                lstm_cell["forward"],
                lstm_cell["backward"],
                lstm_inputs,
                dtype=tf.float32,
                sequence_length=lengths)
        return tf.concat(outputs, axis=2)

    def project_layer(self, lstm_outputs, name=None):
        """
        hidden layer between lstm layer and logits
        :param lstm_outputs: [batch_size, num_steps, emb_size]
        :return: [batch_size, num_steps, num_tags]
        """
        with tf.name_scope("project" if not name else name):
            with tf.name_scope("hidden"):
                W = tf.get_variable("HW", shape=[self.lstm_dim * 2, self.lstm_dim],
                                    dtype=tf.float32, initializer=self.initializer)

                b = tf.get_variable("Hb", shape=[self.lstm_dim], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())
                output = tf.reshape(lstm_outputs, shape=[-1, self.lstm_dim * 2])
                hidden = tf.tanh(tf.nn.xw_plus_b(output, W, b))

            # project to score of tags
            with tf.name_scope("logits"):
                W = tf.get_variable("LW", shape=[self.lstm_dim, self.relation_num],
                                    dtype=tf.float32, initializer=self.initializer)

                b = tf.get_variable("Lb", shape=[self.relation_num], dtype=tf.float32,
                                    initializer=tf.zeros_initializer())

                pred = tf.nn.xw_plus_b(hidden, W, b)

            return tf.reshape(pred, [-1, self.num_steps, self.relation_num], name='pred_logits')

  • CRF层代码(tensorflow.contrib.crf库中的crf_log_likelihood):
    def loss_layer(self, project_logits, lengths, name=None):
        """
        计算CRF的loss
        :param project_logits: [1, num_steps, num_tags]
        :return: scalar loss
        """
        with tf.name_scope("crf_loss" if not name else name):
            small = -1000.0
            # pad logits for crf loss
            start_logits = tf.concat(
                [small * tf.ones(shape=[self.batch_size, 1, self.relation_num]), tf.zeros(shape=[self.batch_size, 1, 1])],
                axis=-1)
            pad_logits = tf.cast(small * tf.ones([self.batch_size, self.num_steps, 1]), tf.float32)
            logits = tf.concat([project_logits, pad_logits], axis=-1)
            logits = tf.concat([start_logits, logits], axis=1)
            targets = tf.concat(
                [tf.cast(self.relation_num * tf.ones([self.batch_size, 1]), tf.int32), self.input_relation], axis=-1)

            self.trans = tf.get_variable(
                name="transitions",
                shape=[self.relation_num + 1, self.relation_num + 1],  # 1
                # shape=[self.relation_num, self.relation_num],  # 1
                initializer=self.initializer)
            log_likelihood, self.trans = crf_log_likelihood(
                inputs=logits,
                tag_indices=targets,
                # tag_indices=self.input_relation,
                transition_params=self.trans,
                # sequence_lengths=lengths
                sequence_lengths=lengths + 1
            )  # + 1
            return tf.reduce_mean(-log_likelihood, name='loss')

4.3模型训练

模型的训练模型参数微调和模型保存。

  • 模型的训练
def train(train_iter, test_iter, config):
    graph = tf.Graph()
    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        session_conf.gpu_options.allow_growth = True
        session = tf.Session(config=session_conf)
        with session.as_default():
            model = Model(config)  # 读取模型结构图

            # 超参数设置
            global_step = tf.Variable(0, name='step', trainable=False)
            learning_rate = tf.train.exponential_decay(config.learning_rate, global_step, config.decay_step,
                                                       config.decay_rate, staircase=True)

            normal_optimizer = tf.train.AdamOptimizer(learning_rate)  # 下接结构的学习率

            all_variables = graph.get_collection('trainable_variables')
            word2vec_var_list = [x for x in all_variables if 'bert' in x.name]  # BERT的参数
            normal_var_list = [x for x in all_variables if 'bert' not in x.name]  # 下接结构的参数
            print('bert train variable num: {}'.format(len(word2vec_var_list)))
            print('normal train variable num: {}'.format(len(normal_var_list)))
            normal_op = normal_optimizer.minimize(model.loss, global_step=global_step, var_list=normal_var_list)
            num_batch = int(train_iter.num_records / config.batch_size * config.train_epoch)
            embed_step = tf.Variable(0, name='step', trainable=False)
            if word2vec_var_list:  # 对BERT微调
                print('word2vec trainable!!')
                word2vec_op, embed_learning_rate, embed_step = create_optimizer(
                    model.loss, config.embed_learning_rate, num_train_steps=num_batch,
                    num_warmup_steps=int(num_batch * 0.05) , use_tpu=False ,  variable_list=word2vec_var_list
                )

                train_op = tf.group(normal_op, word2vec_op)  # 组装BERT与下接结构参数
            else:
                train_op = normal_op

            timestamp = str(int(time.time()))
            out_dir = os.path.abspath(
                os.path.join(config.model_dir, "runs_" + str(gpu_id), timestamp))
            if not os.path.exists(out_dir):
                os.makedirs(out_dir)
            with open(out_dir + '/' + 'config.json', 'w', encoding='utf-8') as file:
                json.dump(config.__dict__, file)
            print("Writing to {}\n".format(out_dir))

            saver = tf.train.Saver(tf.global_variables(), max_to_keep=config.num_checkpoints)
            if config.continue_training:
                print('recover from: {}'.format(config.checkpoint_path))
                saver.restore(session, config.checkpoint_path)
            else:
                session.run(tf.global_variables_initializer())
            cum_step = 0
            """
            在config.py可以设置了23~4个epoch的时候,便可以停止了,保存几个模型,
            再通过check_F1.py来查看每次训练得到的最高F1模型,取最优模型进行预测。
            """
            for i in range(config.train_epoch):  # 训练
                for input_ids_list, input_mask_list, segment_ids_list, label_ids_list, seq_length, tokens_list in tqdm.tqdm(
                        train_iter):

                    feed_dict = {
                        model.input_x_word: input_ids_list,
                        model.input_mask: input_mask_list,
                        model.input_relation: label_ids_list,
                        model.input_x_len: seq_length,

                        model.keep_prob: config.keep_prob,
                        model.is_training: True,
                    }

                    _, step, _, loss, lr = session.run(
                            fetches=[train_op,
                                     global_step,
                                     embed_step,
                                     model.loss,
                                     learning_rate
                                     ],
                            feed_dict=feed_dict)


                    if cum_step % 10 == 0:
                        format_str = 'step {}, loss {:.4f} lr {:.5f}'
                        print(
                            format_str.format(
                                step, loss, lr)
                        )
                    cum_step += 1

                P, R = set_test(model, test_iter, session)
                F = 2 * P * R / (P + R)
                print('dev set : step_{},precision_{},recall_{}'.format(cum_step, P, R))
                if F > 0:  # 保存F1大于0的模型
                    saver.save(session, os.path.join(out_dir, 'model_{:.4f}_{:.4f}'.format(P, R)),
                               global_step=step)
  • 使用验证集,验证每个epoch的效果,并保存模型
def set_test(model, test_iter, session):

    if not test_iter.is_test:
        test_iter.is_test = True

    y_pred_list = []
    y_true_list = []
    ldct_list_tokens = []
    for input_ids_list, input_mask_list, segment_ids_list, label_ids_list, seq_length, tokens_list in tqdm.tqdm(
            test_iter):

        feed_dict = {
            model.input_x_word: input_ids_list,
            model.input_x_len: seq_length,
            model.input_relation: label_ids_list,
            model.input_mask: input_mask_list,

            model.keep_prob: 1,
            model.is_training: False,
        }

        lengths, logits, trans = session.run(
            fetches=[model.lengths, model.logits, model.trans],
            feed_dict=feed_dict
        )

        predict = decode(logits, lengths, trans)
        y_pred_list.append(predict)
        y_true_list.append(label_ids_list)
        ldct_list_tokens.append(tokens_list)


    ldct_list_tokens = np.concatenate(ldct_list_tokens)
    ldct_list_text = []
    for tokens in ldct_list_tokens:
        text = "".join(tokens)
        ldct_list_text.append(text)

    # 获取验证集文本及其标签
    y_pred_list, y_pred_label_list = get_text_and_label(ldct_list_tokens, y_pred_list)
    y_true_list, y_true_label_list = get_text_and_label(ldct_list_tokens, y_true_list)

    print(len(y_pred_label_list))
    print(len(y_true_label_list))

    dict_data = {
        'y_true_label': y_true_label_list,
        'y_pred_label': y_pred_label_list,
        'y_pred_text': ldct_list_text
    }
    df = pd.DataFrame(dict_data)
    precision, recall, f1 = get_P_R_F(df)

    print('precision: {}, recall {}, f1 {}'.format(precision, recall, f1))

    return precision, recall

训练过程的结果如下图所示:在这里插入图片描述
本次实验使用的是6G的显卡,batch_size设置为2,训练完一个epoch大概花费了两个小时左右,最终会在model/runs_0路径下生成训练好的模型:
在这里插入图片描述

4.4模型预测

先通过运行check_F1.py查看最优的预测模型,而后读取模型,再通过之前在model.py给每一个变量设置的变量名得到我们需要的变量,这样做的好处是无需重新构建模型图,只要抽取我们预测所需要的变量即可完成任务,一定程度上对我们的模型代码进行了加密。最后我们对测试集进行预测,并保存所需的预测概率。预测的代码结构和训练的代码相差不大

def get_session(checkpoint_path):
    graph = tf.Graph()

    with graph.as_default():
        session_conf = tf.ConfigProto(allow_soft_placement=True, log_device_placement=False)
        session_conf.gpu_options.allow_growth = True
        session = tf.Session(config=session_conf)
        with session.as_default():
            # Load the saved meta graph and restore variables
            try:
                saver = tf.train.import_meta_graph("{}.meta".format(checkpoint_path))
            except OSError:
                saver = tf.train.import_meta_graph("{}.ckpt.meta".format(checkpoint_path))
            saver.restore(session, checkpoint_path)

            _input_x = graph.get_operation_by_name("input_x_word").outputs[0]
            _input_x_len = graph.get_operation_by_name("input_x_len").outputs[0]
            _input_mask = graph.get_operation_by_name("input_mask").outputs[0]
            _input_relation = graph.get_operation_by_name("input_relation").outputs[0]
            _keep_ratio = graph.get_operation_by_name('dropout_keep_prob').outputs[0]
            _is_training = graph.get_operation_by_name('is_training').outputs[0]


            used = tf.sign(tf.abs(_input_x))
            length = tf.reduce_sum(used, reduction_indices=1)
            lengths = tf.cast(length, tf.int32)
            logits = graph.get_operation_by_name('project/pred_logits').outputs[0]

            trans = graph.get_operation_by_name('transitions').outputs[0]

            def run_predict(feed_dict):
                return session.run([logits, lengths, trans], feed_dict)

    print('recover from: {}'.format(checkpoint_path))
    return run_predict, (_input_x, _input_x_len, _input_mask, _input_relation, _keep_ratio, _is_training)

预测的结果如下图所示,后续需要对该文件进行后处理:
在这里插入图片描述
会在该路径下生成测试文件:
在这里插入图片描述

4.5预测结果后处理

  • ensemble/ensemble.py
    将predict.py生成的概率文件复原成文字结果。其中,remove_list存放的是predict.py生成的概率文件名。如果我们想要还原某个模型的文字结果,直接注释该模型的概率文件即可。当然,vote_ensemble函数也可以对多个模型的概率文件进行还原,相当于是对多个模型的出来的标签进行数字投票,并还原成文字结果。不过笔者在竞赛时统计错误发现数字投票会截断实体,如下图所示。为此,笔者将数字投票改成了对所有单模的文字结果进行投票,我们通过设定阈值,统计每一条数据的预测实体在所有模型的出现次数,当实体出现次数大于阈值时,则认为该实体是未知实体,将其保留。笔者通过这个方法提高了两个百分点的成绩。
def vote_ensemble(path, dataset, output_path, remove_list):
   single_model_list = [x for x in os.listdir(path) if dataset + '_result_detail' in x]
   print('ensemble from file: ')
   for file_name in single_model_list:
       print(file_name)

   pred_list = OrderedDict()
   ldct_list = []
   text_index = -1  # 保证加入的ldct不是ernie模型的
   for index, file in enumerate(single_model_list):
       if file not in remove_list:  # 预测所有模型
           text_index = index
           print(index)
           print('Text File: ', file)
           break
   print('Ensembling.....')
   for index, file in enumerate(single_model_list):

       if file in remove_list:
           # print('remove file: ', file)
           continue
       print('Ensemble file:', file)
       with open(path + file) as f:
           for i, line in tqdm(enumerate(f.readlines())):
               item = json.loads(line)

               if i not in pred_list:
                   pred_list[i] = []
               pred_list[i].append(item['pred'])
               if index == text_index:
                   ldct_list.append(item['ldct_list'])


   print(len(pred_list))
   print(len(ldct_list))
   y_pred_list = []
   print('Getting Result.....')
   for key in tqdm(pred_list.keys()):
       pred_key = np.concatenate(pred_list[key]) # 3维
       j = 0
       temp_list = []
       for i in range(config.batch_size):
           temp = []
           while True:
               try:
                   temp.append(pred_key[j])
                   j += config.batch_size
               except:
                   j = 0
                   j += i + 1
                   break

           temp_T = np.array(temp).T  # 转置
           pred = []
           for line in temp_T:
               pred.append(np.argmax(np.bincount(line)))  # 找出列表中出现次数最多的值
           temp_list.append(pred)
       y_pred_list.append(temp_list)

   ldct_list_tokens = np.concatenate(ldct_list)
   # print(ldct_list)
   ldct_list_text = []
   for tokens in tqdm(ldct_list_tokens):
       text = "".join(tokens)
       ldct_list_text.append(text)
   # 测试集
   print(len(ldct_list_tokens))
   y_pred_list, y_pred_label_list = get_text_and_label(ldct_list_tokens, y_pred_list)

   print(len(y_pred_label_list))
   dict_data = {
       'y_pred_label_list': y_pred_label_list,
       'ldct_list_text': ldct_list_text,
   }
   df = pd.DataFrame(dict_data)
   df = df.fillna("0")
   df.to_csv(output_path + 'test_result.csv', encoding='utf-8')
def score_average_ensemble(path, dataset, output_path, remove_list):

    single_model_list = [x for x in os.listdir(path) if dataset + '_result_detail' in x]
    print('ensemble from file: ', len(single_model_list))
    for file_name in single_model_list:
        print(file_name)
    logits_list = OrderedDict()
    trans_list = OrderedDict()
    lengths_list = OrderedDict()
    ldct_list = []

    text_index = -1
    for index, file in enumerate(single_model_list):
        if file not in remove_list:  # 预测所有模型
            text_index = index
            print('Text File: ', file)
            print(text_index)
            break

    for index, file in enumerate(single_model_list):
        if file in remove_list:
            print('remove file: ', file)
            continue
        with open(path + file) as f:
            for i, line in tqdm(enumerate(f.readlines())):
                item = json.loads(line)

                if i not in logits_list:
                    logits_list[i] = []
                    trans_list[i] = []
                    lengths_list[i] = []

                logits_list[i].append(item['logit'])
                trans_list[i].append(item['trans'])
                lengths_list[i].append(item['lengths'])
                if index == text_index:
                    ldct_list.append(item['ldct_list'])

    y_pred_list = []
    for key in tqdm(logits_list.keys()):

        logits_key = logits_list[key]
        logits_key = np.mean(logits_key, axis=0)

        trans_key = np.array(trans_list[key])
        trans_key = np.mean(trans_key, axis=0)

        lengths_key = np.array(lengths_list[key])
        lengths_key = np.mean(lengths_key, axis=0).astype(int)

        pred = decode(logits_key, lengths_key, trans_key)
        y_pred_list.append(pred)

    ldct_list_tokens = np.concatenate(ldct_list)
    ldct_list_text = []

    for tokens in tqdm(ldct_list_tokens):
        text = "".join(tokens)
        ldct_list_text.append(text)

    # 测试集
    print(len(ldct_list_tokens))
    y_pred_list, y_pred_label_list = get_text_and_label(ldct_list_tokens, y_pred_list)

    print(len(y_pred_label_list))
    dict_data = {
        'y_pred_label_list': y_pred_label_list,
        'ldct_list_text': ldct_list_text,
    }
    df = pd.DataFrame(dict_data)
    df = df.fillna("0")
    df.to_csv(output_path + 'test_result.csv', encoding='utf-8')

4.6模型融合

对多个异构单模的文字结果投票进行多模融合。具体的融合思路如下图所示。我们通过模型构建部分获得了多个异构模型,我们选择高召回率的模型进行单模预测,并将得到的单模结果进行文字投票融合得到最终结果
在这里插入图片描述

5.调试注意事项

(1)自建库或者别人写的非官方库(这里的官方库指的是可以pip/conda install的那种)在pycharm环境中一定要设置成Resource root这样在调用该库时不会报错。
在这里插入图片描述
设置成功后,该文件夹为蓝色,设置过程如下:
在这里插入图片描述

(2)注意初始化文件config.py中一些模型的路径问题:

        self.ensemble_source_file  = 'F:/5.Github项目代码/11.Python深度学习/chapter8/data/ensemble/source_file/'
        self.ensemble_result_file = 'F:/5.Github项目代码/11.Python深度学习/chapter8/data/ensemble/result_file/'

这里建议最好使用“/”而不是windows 文件路径的“\”,因为“\”容易被理解成转义字符
(3)原作者没有给出BERT官方的预训练模型,官方给出了多种不同的BERT模型,这里可以选择合适的模型,放在相应路径即可。下载速度慢的可以参考:https://gitee.com/cheng_jinpei/bert
(4)由于笔者水平和经验有限,该解析文件肯定存在一些漏洞,欢迎大家进行批评与指正

  • 1
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值