命名实体识别(1)

文章目录


写一下最近正在做的一个命名实体识别项目,还没结束,这里先放一段代码
main.py

#系统包
import os
import tensorflow as tf
import pickle

#自定义包
import data_loader
import data_utils
import model_utils


flags = tf.app.flags

#训练相关的
flags.DEFINE_boolean('train',True,'是否开始训练')
flags.DEFINE_boolean('clean',True,'是否清理文件')

#配置相关
flags.DEFINE_integer('seg_dim',20,'seg embedding size')
flags.DEFINE_integer('word_dim',120,'word embedding')
flags.DEFINE_integer('lstm_dim',120,'Num of hiddem unis in lstm')
flags.DEFINE_string('tag_schema','BIOES','编码方式')


##训练相关的e
flags.DEFINE_float('clip',5,'Grandient clip')
flags.DEFINE_float('dropout',0.5,'Dropout rate')
flags.DEFINE_integer('batch_size',120,'batch_size')
flags.DEFINE_float('lr',0.001,'learning rate')
flags.DEFINE_string('optimizer','adam','优化器')
flags.DEFINE_boolean('pre_emb',True,'是否使用预训练')


flags.DEFINE_integer('max_epoch',100,'最大轮训次数')
flags.DEFINE_integer('steps_chech',100,'steps per checkpoint')
flags.DEFINE_string('ckpt_path',os.path.join('model','ckpt'),'保存模型的位置')
flags.DEFINE_string('log_file','train_log','训练过程中日志')
flags.DEFINE_string('map_file','maps.pkl','存放字典映射以及标签映射')
flags.DEFINE_string('vocab_file','vocab.json','词典')
flags.DEFINE_string('config_file','config_file','配置文件')
flags.DEFINE_string('train_file',os.path.join('data','ner.train'),'训练数据路径')
flags.DEFINE_string('dev_file',os.path.join('data','ner.dev'),'校验数据路径')
flags.DEFINE_string('test_file',os.path.join('data','ner.test'),'测试数据路径')


FLAGS = tf.app.flags.FLAGS
assert FLAGS.clip < 5.1,'梯度裁剪不能过大'
assert 0 < FLAGS.dropout < 1, 'dropout必须在0和1之间'
assert FLAGS.lr >0,'lr 必须大于0'
assert FLAGS.optimizer in ['adam','sgd','adagrad'],'优化器必须在这三者之间'


def train():
    #加载数据
    train_sentences = data_loader.load_sentences(FLAGS.train_file)
    dev_sentences = data_loader.load_sentences(FLAGS.dev_file)
    test_sentences = data_loader.load_sentences(FLAGS.test_file)

    #转换编码bio转bioes
    data_loader.update_tag_scheme(train_sentences,FLAGS.tag_schema)
    data_loader.update_tag_scheme(test_sentences,FLAGS.tag_schema)
    data_loader.update_tag_scheme(dev_sentences,FLAGS.tag_schema)



#创建单词映射
    if not os.path.isfile(FLAGS.map_file):
        _,word_to_id,id_to_word=data_loader.word_mapping(train_sentences)
        _,tag_to_id,id_to_tag = data_loader.tag_mapping(train_sentences)

        with open(FLAGS.map_file,"wb") as f:#第一次会走这里,会创建maps.pkl这个文件
            pickle.dump([word_to_id,id_to_word,tag_to_id,id_to_tag],f)

    else:
        with open(FLAGS.map_file,'rb') as f:#第二次或者以后会走这里
            word_to_id,id_to_word,tag_to_id,id_to_tag = pickle.load(f)


    train_data = data_loader.prepare_dataset(train_sentences,word_to_id,tag_to_id)
    dev_data = data_loader.prepare_dataset(dev_sentences,word_to_id,tag_to_id)
    test_data = data_loader.prepare_dataset(test_sentences,word_to_id,tag_to_id)


    print("train_data_num%i,dev_data_num%i,test_data_num%i"%(len(train_data),len(dev_data),len(test_data)))

    # config = model_utils.config_model(FLAGS, word_to_id, tag_to_id)

    model_utils.make_path(FLAGS)

    if os.path.isfile(FLAGS.config_file):#查看config_file是否存在,存在,则load_cinfig
        config = model_utils.load_config(FLAGS.config_file)
    else:#如果confifile不存在,执行下面语句,先配置,在保存,下一次执行代码时,就执行上面一句了,直接加载!
        config = model_utils.config_model(FLAGS,word_to_id,tag_to_id)
        model_utils.save_config(config,FLAGS.config_file)


    log_path = os.path.join("log",FLAGS.log_file)
    logger = model_utils.get_logger(log_path)
    model_utils.print_config(config,logger)

    print("hello")





def main(_):
    if FLAGS.train:
        train()
    else:
        pass


if __name__ =="__main__":
    tf.app.run(main)



model_utils.py

from  collections import OrderedDict
import os
import json
import logging


def get_logger(log_file):
    """
    定义日志方法(这个方法是通用的)
    :param log_file:
    :return:
    """
    #创建一个logger的实例
    logger =  logging.getLogger(log_file)

    #设置Logger的全局日志级别为DEBUG
    logger.setLevel(logging.DEBUG)

    #创建一个日志文件的handler,并且设置日志级别的DEBUG
    fh = logging.FileHandler(log_file)
    fh.setLevel(logging.DEBUG)
    #创建一个控制台的handler,并且设置日志级别为DEBUG
    ch = logging.StreamHandler()
    ch.setLevel(logging.INFO)
    #设置日志格式
    formatter = logging.Formatter("%(asctime)s-%(name)s-%(levelname)s-%(message)s")
    #add formatter to ch and fh
    ch.setFormatter(formatter)
    fh.setFormatter(formatter)
    #add ch and fh to logger
    logger.addHandler(ch)
    logger.addHandler(fh)
    return logger



#模型配置
def config_model(FLAGS,word_to_id,tag_to_id):
    config = OrderedDict()                          #有序字典
    config['num_words'] = len(word_to_id)
    config['word_dim'] = FLAGS.word_dim
    config['num_tags'] = len(tag_to_id)
    config['seg_dim'] = FLAGS.seg_dim
    config['list_dim'] = FLAGS.lstm_dim
    config['batch_size'] = FLAGS.batch_size

    config['clip'] = FLAGS.clip
    config['dropout_keep'] = 1.0 - FLAGS.dropout
    config['optimizer'] = FLAGS.optimizer
    config['lr'] = FLAGS.lr
    config['tag_schema'] = FLAGS.tag_schema
    config['pre_emb'] = FLAGS.pre_emb
    return config



def make_path(params):
    """
    创建文件夹
    :param params:
    :return:
    """
    if not os.path.isdir(params.ckpt_path):
        os.makedirs(params.ckpt_path)
    if not os.path.isdir('log'):#创建Log文件
            os.makedirs('log')

def save_config(config,config_file):
    """
    保存配置文件
    :param config:
    :param config_path:
    :return:
    """
    with open(config_file,'w',encoding='utf-8')as f:
        json.dump(config,f,ensure_ascii=False,indent =4)


def load_config(config_file):
    """
    加载配置文件
    :param config_file:
    :return:
    """
    with open(config_file,encoding='utf-8') as f:
        return json.load(f)




def print_config(config,logger):#通用方法,打印日志
    """
    打印模型参数
    :param config:
    :param logger:
    :return:
    """
    for k,v in config.items():
        logger.info("{}:\t{}".format(k.ljust(15),v))


data_loader.py



import codecs
import data_utils
def load_sentences(path):
    """
    加载数据集,每行包含一个汉字和一个标记
    句子和句子之间是以空格进行分割的
    最后返回句子集合
    :param path:
    :return:
    """
    #存放数据集
    sentences = []
    #临时存放每一个句子
    sentence = []
    for line in codecs.open(path,'r',encoding='utf-8'):#这个循环结束,将会将数据添加到sentences
        #去掉两边空格
        line = line.strip()
        #首先判断是不是空,如果是则表示句子和句子之间的分割点
        if not line:#如果这行是空格
            if len(sentence) > 0:#判断这个句子里面是否为空(长度是否大于0), 在这里下断点,利用resume program这个键(跳到下个断点),这里相当于是从一个句子直接跳至这个句子结束
                sentences.append(sentence)#如果大于0,说明句子里面有东西,将其添加到句子集合里面
                #清空sentence,表示一句话完结
                sentence = []              #由于下面还需要用到这个临时的,故这里将其清空
        else:       #如果这行不是空
            if line[0] == " ":#判断第一个是否为空,如果为空,表示这不是一个合法的
                continue
            else:#如果第一个不是空的
                word = line.split()#使用空格对字符串进行切分
                assert len(word) >= 2
                sentence.append(word)#将上面的信息添加到句子里面
    #循环走完,要判断一个,防止句子没有进入到句子集合里面
    if len(sentence) > 0:
        sentences.append(sentence)
    return sentences

def update_tag_scheme(sentences,tag_scheme):
    """
    更新为指定编码
    :param sentences:
    :param tag_scheme:
    :return:
    """
    for i,s in enumerate(sentences):
        tags = [w[-1] for w in s ]      #取出编码
        if not data_utils.chek_bio(tags):#如果不是bio编码(做转换之前校验一下是否为我们的BIO编码)
            s_str ="\n".join("".join(w) for w in s)
            raise Exception("输入的句子应为BIO编码,请检查输入句子%i:\n%s"%(i,s_str))

        if tag_scheme == "BIO":
            for word,new_tag in zip(s,tags):
                word[-1] = new_tag

        if tag_scheme == "BIOES":
            new_tags =data_utils.bio_to_bioes(tags)
            for word, new_tag in zip(s,new_tags):
                word[-1] = new_tag
        else:
            raise Exception("非法目标编码")

def word_mapping(sentences):            #这个函数在NLP领域内经常用到!!比如分类
    """
    构建字典
    :param sentences:
    :return:
    """
    #这里有个列表推导式,这是个重点!!!
    word_list = [ [x[0] for x in s]for s in sentences]#将每个句子里面的word提炼出来,eg:[['相', 'O'], ['比', 'O'], ['之', 'O'], ['下', 'O'], [',', 'O'], ['青', 'B-ORG']],将里面的"相比之下"这些字提炼出来了
    dico = data_utils.create_dico(word_list)            #dico里面存放的是每个单词以及其对应的次数,
    dico['<PAD>'] = 10000001
    dico['<UNK>'] = 10000000
    word_to_id,id_to_word =data_utils.create_mapping(dico)
    return dico,word_to_id,id_to_word


def tag_mapping(sentences):     #序列标注的时候会用到
    """
    构建标签字典
    :param sentences:
    :return:
    """
    tag_list = [[ x[1] for x in s] for s in sentences]
    dico = data_utils.create_dico(tag_list)
    tag_to_id,id_to_tag =data_utils.create_mapping(dico)
    return dico,tag_to_id,id_to_tag



def  prepare_dataset(sentences,word_to_id,tag_to_id,train = True):
    """
    数据预处理,返回list,其实包含:
    -word_list
    -word_id_list
    -word char indexs
    -tag_id_list
    :param sentences:
    :param word_to_id:
    :param tag_to_id:
    :param train:
    :return:
    """
    none_index = tag_to_id['O']#字母O

    data =[]
    for s in sentences:
        word_list = [w[0] for w in s]
        word_id_list = [ word_to_id [w if w in word_to_id else '<UNK>'] for w in word_list]#遍历word_list,由于集合里面不可能包含字典里面的词,判断一下,如果在word_to_id里面就取w,否则就取UNK(表示不在字典里面)
        segs  = data_utils.get_seg_features("".join(word_list))
        if train:
            tag_id_list = [tag_to_id[w[-1]] for w in s]#tag_to_id[w[-1]]:将tag拿出来
        else:
            tag_id_list = [none_index for w in s]
        data.append([word_list,word_id_list,segs,tag_id_list])
    return data



if __name__ =="__main__":
    path = "data/ner.dev"
    sentences = load_sentences(path)
    update_tag_scheme(sentences,"BIOES")
    _,word_to_id,id_to_word= word_mapping(sentences)
    _,tag_to_id,id_to_tag=tag_mapping(sentences)        #_:表示默认值
    dev_data = prepare_dataset(sentences,word_to_id,tag_to_id)
    data_utils.BatchManager(dev_data,120)

data_utils.py

import jieba
import math
import random

def chek_bio(tags):
    """
     检测输入的tags是否为BIO编码
     如果不是bio编码
     那么错误的类型
     1)编码不在BIO中
     2)第一个编码是I
     3)当前编码不是B,前一个编码不是O
    :param tags:
    :return:
    """
    for i,tag in enumerate(tags):
        if tag =='O':#此时为BIO
            continue
        tag_list = tag.split("-")
        if len(tag_list) != 2 or tag_list[0] not in set(['B','I']):#此时为非法编码,分割之后的长度不是2,同时编码不在B和I中;
            return False
        if tag_list[0] == 'B':#此时为合法BIO编码
            continue
        elif i == 0 or tags[i-1] == 'O':#如果第一个位置不是B,同时i等于0(I是第一个位置),上一个编码等于O(字母)
            tags[i] ='B' + tag[1:]     #如果当前第一个位置不是B,或者当前编码不是B并且前一个编码0,则全部转换成
        elif tags[i-1][1:] ==tag[1:]:
            #如果当前编码的后面类型编码与tags中的前一个编码中的后面类型编码相同,则跳过
            continue
        else:
            #如果编码类型不一致,则重新从B开始
            tags[i] = 'B' + tag[1:]
    return True


def bio_to_bioes(tags):
    """
    把bio编码转换成bios
    返回新的tags
    :param tags:
    :return:
    """
    new_tags = []
    for i ,tag in enumerate(tags):
        if tag =='O':
            #直接保留,不变化
            new_tags.append(tag)
        elif tag.split('-')[0] =='B':
            #如果tag是以B开头,那么我们就要做下面的判断:
            #首先,如果当前tag不是最后一个,并且紧跟着的后一个是I,eg:B-ORG后面是I-ORG
            if (i + 1) < len(tags) and tags[i +1].split('-')[0] =='I':#i + 1 < len(tags):如果不是最后一个tag
                #直接保留
                new_tags.append(tag)
            else:#如果是最后一个或者后面一个不是I;eg:B-ORG后面一个也是B-ORG,那么前面这个B-ORG就会变成S-ORG
                #如果是最后一个或者紧跟着的后一个不是I,那么表示,需要把B换成S表示单字
                new_tags.append(tag.replace('B-','S-'))
        elif tag.split('-')[0]=='I':
            #如果tag是以I开头,那么我们需要进行下面的判断
            #首先,如果当前tag不是最后一个,并且紧跟着的一个是I,eg:I-ORG后面还是I-ORG,如果I-ORG后面是O(字母)就不行
            if (i + 1)<len(tags) and tags[i+1].split('-')[0]=='I':
                #直接保留
                new_tags.append(tag)
            else:
                #如果是最后一个或者I-ORG后面一个不是以I开头的(也不是以B开头的),那么就表示一个词的结尾,就把I换成E表示一个词的结尾
                new_tags.append(tag.replace('I-','E-'))
        else:
            raise  Exception('非法编码')
    return new_tags


def create_dico(item_list):
    """
    对于item_list里面,每个item,统计item_list中item在item_list的的次数
    item:出现的次数
    :param item_list:
    :return:
    """
    assert type(item_list) is list
    dico = {}
    for items in item_list:
        for item in items:
            if item not in dico:
                dico[item] = 1
            else:
                dico[item] += 1
    return dico


def create_mapping(dico):
    """
    创建item to id,id to item
    item的排序按照词典中出现的次数

    :param dico:
    :return:
    """
    sorted_items = sorted(dico.items(),key=lambda x:(-x[1],x[0]))  #将dico(字典)里面的key按照降序排
    # sorted_items = sorted(dico.items(),key=lambda x:(x[1],x[0]))  #将dico(字典)里面的key按照升序排
    id_to_item = {i:v[0] for i,v in enumerate(sorted_items)}        #将字典里面的每个字进行编号,出现的次数越多,排在越前面,编号越小;这里的形式为---》编号: 词
    item_to_id = {v:k for k,v in id_to_item.items()}                #这里的形式为---》词:编号
    return item_to_id,id_to_item                                    #注意智力的顺序


def get_seg_features(words):
    """
    利用jieba分词
    采用类似bioes的编码,0表示单个字成词,1表示一个词的开始,2表示一个词的中间,3表示一个词的结尾
    :param words:
    :return:
    """
    seg_features = []
    word_list = list(jieba.cut(words))

    for word in word_list:
        if len(word) ==1:           #表示单个成词,就填0
            seg_features.append(0)
        else:
            temp = [2]*len(word)
            temp[0] = 1
            temp[-1] = 3
            seg_features.extend(temp)
    return seg_features


class BatchManager(object):
    def __init__(self,data,batch_size):
        self.batch_data = self.sort_and_pad(data,batch_size)
        self.len_data = len(self.batch_data)

    def sort_and_pad(self,data,batch_size):
        num_batch = int(math.ceil(len(data)/batch_size))    #计算有多少批次
        sorted_data = sorted(data,key=lambda x :len(x[0])) #这里是按照len的升序排
        # sorted_data1 = sorted(data,key=lambda x :-len(x[0])) #这里是按照len(长度)的降序排
        batch_data = list()
        for i in range(num_batch):
            batch_data.append(self.pad_data(sorted_data[i*batch_size :(i+1)*batch_size]))

        return batch_data

    @staticmethod
    def pad_data(data):#数据填充函数
        word_list = []
        word_id_list =[]
        seg_list=[]
        tag_id_list =[]
        max_length =max( [len(sentence[0]) for sentence in data])   #一批有120(自己设置的)个句子(样本),最长数据是17(这里的demo是17,下一批最大是20,每一批的最大值不一样的!以每批有120个样本计算,这里有20批)
        for line in data:
            words,word_ids,segs,tag_ids = line      #单词,单词索引,分词信息(分词特征信息),tag索引
            padding = [0] *(max_length - len(words))        #需要填充的数据
            word_list.append(words + padding)
            word_id_list.append(word_ids + padding)
            seg_list.append(segs + padding)
            tag_id_list.append(tag_ids + padding)
        return [word_list,word_id_list,seg_list,tag_id_list]


    def iter_batch(self,shuffle=False):
        if shuffle:
            random.shuffle(self.batch_data)
        for idx in range(self.len_data):
            yield self.batch_data[idx]



后面代码可能会做补充和拓展

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值