BiLSTM-CRF

model.py
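
The graph below looks up character embeddings, runs a bidirectional LSTM over each
sentence, projects the concatenated hidden states to per-tag scores, and decodes with
either a CRF layer or a per-step softmax. With the defaults used in main.py
(hidden_dim=300 and 300-dimensional embeddings), the main tensor shapes are roughly:

    word_ids             [batch, max_len]              int32 character ids
    word_embeddings      [batch, max_len, 300]         embedding lookup + dropout
    bi-LSTM output       [batch, max_len, 2*300]       forward/backward concatenation
    logits               [batch, max_len, num_tags]    projection layer
    transition_params    [num_tags, num_tags]          CRF transition scores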

import numpy as np
import os, time, sys
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.crf import crf_log_likelihood
from tensorflow.contrib.crf import viterbi_decode
from data import pad_sequences, batch_yield
from utils import get_logger
from eval import conlleval


class BiLSTM_CRF(object):
    def __init__(self, args, embeddings, tag2label, vocab, paths, config):
        self.batch_size = args.batch_size
        self.epoch_num = args.epoch
        self.hidden_dim = args.hidden_dim
        self.embeddings = embeddings
        self.CRF = args.CRF
        self.update_embedding = args.update_embedding
        self.dropout_keep_prob = args.dropout
        self.optimizer = args.optimizer
        self.lr = args.lr
        self.clip_grad = args.clip
        self.tag2label = tag2label
        self.num_tags = len(tag2label)
        self.vocab = vocab
        self.shuffle = args.shuffle
        self.model_path = paths['model_path']
        self.summary_path = paths['summary_path']
        self.logger = get_logger(paths['log_path'])
        self.result_path = paths['result_path']
        self.config = config

    def build_graph(self):
        self.add_placeholders()
        self.lookup_layer_op()
        self.biLSTM_layer_op()
        self.softmax_pred_op()
        self.loss_op()
        self.trainstep_op()
        self.init_op()

    def add_placeholders(self):
        self.word_ids = tf.placeholder(tf.int32, shape=[None, None], name="word_ids")
        self.labels = tf.placeholder(tf.int32, shape=[None, None], name="labels")
        self.sequence_lengths = tf.placeholder(tf.int32, shape=[None], name="sequence_lengths")

        self.dropout_pl = tf.placeholder(dtype=tf.float32, shape=[], name="dropout")
        self.lr_pl = tf.placeholder(dtype=tf.float32, shape=[], name="lr")

    def lookup_layer_op(self):
        with tf.variable_scope("words"):
            _word_embeddings = tf.Variable(self.embeddings,
                                           dtype=tf.float32,
                                           trainable=self.update_embedding,
                                           name="_word_embeddings")
            word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,
                                                     ids=self.word_ids,
                                                     name="word_embeddings")
        self.word_embeddings = tf.nn.dropout(word_embeddings, self.dropout_pl)

    def biLSTM_layer_op(self):
        with tf.variable_scope("bi-lstm"):
            cell_fw = LSTMCell(self.hidden_dim)
            cell_bw = LSTMCell(self.hidden_dim)
            (output_fw_seq, output_bw_seq), _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=cell_fw,
                cell_bw=cell_bw,
                inputs=self.word_embeddings,
                sequence_length=self.sequence_lengths,
                dtype=tf.float32)
            output = tf.concat([output_fw_seq, output_bw_seq], axis=-1)
            output = tf.nn.dropout(output, self.dropout_pl)

        with tf.variable_scope("proj"):
            W = tf.get_variable(name="W",
                                shape=[2 * self.hidden_dim, self.num_tags],
                                initializer=tf.contrib.layers.xavier_initializer(),
                                dtype=tf.float32)

            b = tf.get_variable(name="b",
                                shape=[self.num_tags],
                                initializer=tf.zeros_initializer(),
                                dtype=tf.float32)

            s = tf.shape(output)
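            # Flatten the [batch, time, 2*hidden_dim] output to 2-D so a single
            # matmul applies the same projection at every time step, then reshape
            # back to [batch, time, num_tags] to get per-step tag scores.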
            output = tf.reshape(output, [-1, 2*self.hidden_dim])
            pred = tf.matmul(output, W) + b

            self.logits = tf.reshape(pred, [-1, s[1], self.num_tags])

    def loss_op(self):
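        # With CRF=True, crf_log_likelihood adds learned transition scores between
        # adjacent tags to the per-step scores in self.logits, normalizes over all
        # possible tag sequences, and ignores positions beyond sequence_lengths;
        # the loss is the mean negative log-likelihood of the gold sequences.
        # With CRF=False, a per-step softmax cross-entropy is used instead and the
        # padded positions are masked out with tf.sequence_mask.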
        if self.CRF:
            log_likelihood, self.transition_params = crf_log_likelihood(inputs=self.logits,
                                                                   tag_indices=self.labels,
                                                                   sequence_lengths=self.sequence_lengths)
            self.loss = -tf.reduce_mean(log_likelihood)

        else:
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                                                    labels=self.labels)
            mask = tf.sequence_mask(self.sequence_lengths)
            losses = tf.boolean_mask(losses, mask)
            self.loss = tf.reduce_mean(losses)

        tf.summary.scalar("loss", self.loss)

    def softmax_pred_op(self):
        if not self.CRF:
            self.labels_softmax_ = tf.argmax(self.logits, axis=-1)
            self.labels_softmax_ = tf.cast(self.labels_softmax_, tf.int32)

    def trainstep_op(self):
        with tf.variable_scope("train_step"):
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            if self.optimizer == 'Adam':
                optim = tf.train.AdamOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Adadelta':
                optim = tf.train.AdadeltaOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Adagrad':
                optim = tf.train.AdagradOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'RMSProp':
                optim = tf.train.RMSPropOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Momentum':
                optim = tf.train.MomentumOptimizer(learning_rate=self.lr_pl, momentum=0.9)
            elif self.optimizer == 'SGD':
                optim = tf.train.GradientDescentOptimizer(learning_rate=self.lr_pl)
            else:
                optim = tf.train.GradientDescentOptimizer(learning_rate=self.lr_pl)

            grads_and_vars = optim.compute_gradients(self.loss)
            grads_and_vars_clip = [[tf.clip_by_value(g, -self.clip_grad, self.clip_grad), v] for g, v in grads_and_vars]
            self.train_op = optim.apply_gradients(grads_and_vars_clip, global_step=self.global_step)

    def init_op(self):
        self.init_op = tf.global_variables_initializer()

    def add_summary(self, sess):
        """

        :param sess:
        :return:
        """
        self.merged = tf.summary.merge_all()
        self.file_writer = tf.summary.FileWriter(self.summary_path, sess.graph)

    def train(self, train, dev):
        """

        :param train:
        :param dev:
        :return:
        """
        saver = tf.train.Saver(tf.global_variables())

        with tf.Session(config=self.config) as sess:
            sess.run(self.init_op)
            self.add_summary(sess)

            for epoch in range(self.epoch_num):
                self.run_one_epoch(sess, train, dev, self.tag2label, epoch, saver)

    def test(self, test):
        saver = tf.train.Saver()
        with tf.Session(config=self.config) as sess:
            self.logger.info('=========== testing ===========')
            saver.restore(sess, self.model_path)
            label_list, seq_len_list = self.dev_one_epoch(sess, test)
            self.evaluate(label_list, seq_len_list, test)

    def demo_one(self, sess, sent):
        """

        :param sess:
        :param sent: 
        :return:
        """
        label_list = []
        for seqs, labels in batch_yield(sent, self.batch_size, self.vocab, self.tag2label, shuffle=False):
            label_list_, _ = self.predict_one_batch(sess, seqs)
            label_list.extend(label_list_)
        label2tag = {}
        for tag, label in self.tag2label.items():
            label2tag[label] = tag if label != 0 else label
        tag = [label2tag[label] for label in label_list[0]]
        return tag

    def run_one_epoch(self, sess, train, dev, tag2label, epoch, saver):
        """

        :param sess:
        :param train:
        :param dev:
        :param tag2label:
        :param epoch:
        :param saver:
        :return:
        """
        num_batches = (len(train) + self.batch_size - 1) // self.batch_size

        start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        batches = batch_yield(train, self.batch_size, self.vocab, self.tag2label, shuffle=self.shuffle)
        for step, (seqs, labels) in enumerate(batches):

            sys.stdout.write(' processing: {} batch / {} batches.'.format(step + 1, num_batches) + '\r')
            step_num = epoch * num_batches + step + 1
            feed_dict, _ = self.get_feed_dict(seqs, labels, self.lr, self.dropout_keep_prob)
            _, loss_train, summary, step_num_ = sess.run([self.train_op, self.loss, self.merged, self.global_step],
                                                         feed_dict=feed_dict)
            if step + 1 == 1 or (step + 1) % 300 == 0 or step + 1 == num_batches:
                self.logger.info(
                    '{} epoch {}, step {}, loss: {:.4}, global_step: {}'.format(start_time, epoch + 1, step + 1,
                                                                                loss_train, step_num))

            self.file_writer.add_summary(summary, step_num)

            if step + 1 == num_batches:
                saver.save(sess, self.model_path, global_step=step_num)

        self.logger.info('===========validation / test===========')
        label_list_dev, seq_len_list_dev = self.dev_one_epoch(sess, dev)
        self.evaluate(label_list_dev, seq_len_list_dev, dev, epoch)

    def get_feed_dict(self, seqs, labels=None, lr=None, dropout=None):
        """

        :param seqs:
        :param labels:
        :param lr:
        :param dropout:
        :return: feed_dict
        """
        word_ids, seq_len_list = pad_sequences(seqs, pad_mark=0)

        feed_dict = {self.word_ids: word_ids,
                     self.sequence_lengths: seq_len_list}
        if labels is not None:
            labels_, _ = pad_sequences(labels, pad_mark=0)
            feed_dict[self.labels] = labels_
        if lr is not None:
            feed_dict[self.lr_pl] = lr
        if dropout is not None:
            feed_dict[self.dropout_pl] = dropout

        return feed_dict, seq_len_list

    def dev_one_epoch(self, sess, dev):
        """

        :param sess:
        :param dev:
        :return:
        """
        label_list, seq_len_list = [], []
        for seqs, labels in batch_yield(dev, self.batch_size, self.vocab, self.tag2label, shuffle=False):
            label_list_, seq_len_list_ = self.predict_one_batch(sess, seqs)
            label_list.extend(label_list_)
            seq_len_list.extend(seq_len_list_)
        return label_list, seq_len_list

    def predict_one_batch(self, sess, seqs):
        """

        :param sess:
        :param seqs:
        :return: label_list
                 seq_len_list
        """
        feed_dict, seq_len_list = self.get_feed_dict(seqs, dropout=1.0)

        if self.CRF:
            logits, transition_params = sess.run([self.logits, self.transition_params],
                                                 feed_dict=feed_dict)
            label_list = []
            for logit, seq_len in zip(logits, seq_len_list):
                viterbi_seq, _ = viterbi_decode(logit[:seq_len], transition_params)
                label_list.append(viterbi_seq)
            return label_list, seq_len_list

        else:
            label_list = sess.run(self.labels_softmax_, feed_dict=feed_dict)
            return label_list, seq_len_list

    def evaluate(self, label_list, seq_len_list, data, epoch=None):
        """

        :param label_list:
        :param seq_len_list:
        :param data:
        :param epoch:
        :return:

        tag, label: O 0
        tag, label: B 1
        tag, label: I 2
        tag, label: E 3

        tag2label = {"O": 0, "B": 1, "I": 2, "E":3}
        label2tag = {0: 0, 1: 'B', 2: 'I', 3: 'E'}
        
        model_predict[0]:
        [['随', 'O', 0], ['着', 'O', 0], ['电', 'B', 'B'], ['商', 'I', 'I'], ['平', 'I', 'I'], ['台', 'E', 'E']]
        """
        label2tag = {}
        for tag, label in self.tag2label.items():
            print("tag, label:",tag, label)
            # label2tag[label] = tag if label != 0 else label
            # 更改:
            if label != 0:
                label2tag[label] = tag
            else:
                label2tag[label] = label

        model_predict = []  # collected predictions: [char, gold tag, predicted tag]
        for label_, (sent, tag) in zip(label_list, data):
            tag_ = [label2tag[label__] for label__ in label_]
            sent_res = []
            if len(label_) != len(sent):
                print(sent)
                print(len(label_))
                print(tag)
            for i in range(len(sent)):
                sent_res.append([sent[i], tag[i], tag_[i]])
            model_predict.append(sent_res)  # sent_res: [char, gold tag, predicted tag]
        print("model_predict0:\n", model_predict[0])
        # print("model_predict1:\n", model_predict[1])
        epoch_num = str(epoch + 1) if epoch is not None else 'test'
        label_path = os.path.join(self.result_path, 'label_' + epoch_num)
        metric_path = os.path.join(self.result_path, 'result_metric_' + epoch_num)
        for _ in conlleval(model_predict, label_path, metric_path):
            self.logger.info(_)


main.py

import tensorflow as tf
import numpy as np
import os, argparse, time, random
from model import BiLSTM_CRF
from utils import str2bool, get_logger, get_entity, get_PRO_entity
from linking import linking
from data import read_corpus, read_dictionary, tag2label, random_embedding


## Session configuration
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # default: 0
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.8  # need ~700MB GPU memory


## hyperparameters
parser = argparse.ArgumentParser(description='BiLSTM-CRF for Chinese NER task')
parser.add_argument('--data_path', type=str, default='final/hadoop_train0/data_path', help='train data source')
# parser.add_argument('--test_data', type=str, default='final/data_path', help='test data source')
parser.add_argument('--batch_size', type=int, default=1, help='#sample of each minibatch')
parser.add_argument('--epoch', type=int, default=40, help='#epoch of training')
parser.add_argument('--hidden_dim', type=int, default=300, help='#dim of hidden state')
parser.add_argument('--optimizer', type=str, default='SGD', help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
parser.add_argument('--CRF', type=str2bool, default=True, help='use CRF at the top layer. if False, use Softmax')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate')
parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout keep_prob')
parser.add_argument('--update_embedding', type=str2bool, default=True, help='update embedding during training')
parser.add_argument('--pretrain_embedding', type=str, default='pre', help='use pretrained char embedding or init it randomly, pre/random')
parser.add_argument('--embedding_dim', type=int, default=300, help='random init char embedding_dim')
parser.add_argument('--shuffle', type=str2bool, default=True, help='shuffle training data before each epoch')
parser.add_argument('--mode', type=str, default='demo', help='train/test/demo/eval_data')
parser.add_argument('--demo_model', type=str, default='1521112368', help='model for test and demo')
# parser.add_argument('--demo_data_path', type=str, default='final/data_path', help='demo data path')
args = parser.parse_args()
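
# Typical invocations (illustrative; --demo_model must name an existing timestamped
# run directory created by a previous training run, e.g. the default 1521112368):
#   python main.py --mode train
#   python main.py --mode test --demo_model 1521112368
#   python main.py --mode demo --demo_model 1521112368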


## get char embeddings
print("=== load word2id...")
word2id = read_dictionary(os.path.join('.', args.data_path, 'word2id.pkl'))
if args.pretrain_embedding == 'random':
    print("=== random embedding...")
    embeddings = random_embedding(word2id, args.embedding_dim)
    print("== embeddings type:", type(embeddings))
    print("== embeddings shape:", embeddings.shape) # 2960 * 300
    print("embeddings[0]:", embeddings[0]) # 1 * 300
else:
    print("=== pretrain embedding...")
    embedding_path = './pretrain_embedding.npy'
    embeddings = np.array(np.load(embedding_path), dtype='float32')


## read corpus and get training data
if args.mode != 'demo':
    train_path = os.path.join('.', args.data_path, 'train_data')
    print("=== train_path:", train_path)
    test_path = os.path.join('.', args.data_path, 'test_data')  # test_data and train_data live in the same data_path directory
    print("=== test_path:", test_path)
    
    train_data = read_corpus(train_path)
    print("=== train data size:", len(train_data))
    test_data = read_corpus(test_path)
    test_size = len(test_data)
    print("=== test data size:", test_size)


## paths setting
paths = {}
timestamp = str(int(time.time())) if args.mode == 'train' else args.demo_model
print("=== model_name:", timestamp)
output_path = os.path.join('.', args.data_path+"_save", timestamp)
if not os.path.exists(output_path): os.makedirs(output_path)
summary_path = os.path.join(output_path, "summaries")
paths['summary_path'] = summary_path
if not os.path.exists(summary_path): os.makedirs(summary_path)
model_path = os.path.join(output_path, "checkpoints/")
if not os.path.exists(model_path): os.makedirs(model_path)
ckpt_prefix = os.path.join(model_path, "model")
paths['model_path'] = ckpt_prefix
result_path = os.path.join(output_path, "results")
paths['result_path'] = result_path
if not os.path.exists(result_path): os.makedirs(result_path)
log_path = os.path.join(result_path, "log.txt")
paths['log_path'] = log_path
get_logger(log_path).info(str(args))
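
# With the defaults above, a training run leaves its artifacts under roughly:
#   ./final/hadoop_train0/data_path_save/<timestamp>/
#       summaries/                      TensorBoard event files
#       checkpoints/model-*             saved weights (checkpoint prefix "model")
#       results/log.txt                 training log
#       results/label_<epoch>           per-epoch predictions for conlleval
#       results/result_metric_<epoch>   conlleval metrics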


## training model
if args.mode == 'train':
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()

    ## hyperparameters-tuning, split train/dev
    # dev_data = train_data[:5000]; dev_size = len(dev_data)
    # train_data = train_data[5000:]; train_size = len(train_data)
    # print("train data: {0}\ndev data: {1}".format(train_size, dev_size))
    # model.train(train=train_data, dev=dev_data)

    ## train model on the whole training data
    print("=== train data size: {}".format(len(train_data)))
    model.train(train=train_data, dev=test_data)  # use test_data as the dev_data to see overfitting phenomena

## testing model
if args.mode == 'test':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)


## demo
if args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        while True:
            print('Please input your sentence:')
            demo_sent = input()
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                demo_sent = list(demo_sent.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                # print("sess:\n", sess)
                print("demo_data:\n", demo_data)
                tag = model.demo_one(sess, demo_data)
                print("tag:\n", tag)
                PRO = get_PRO_entity(tag, demo_sent)
                print('PRO: {}\n'.format(PRO))

                # PER, LOC, ORG = get_entity(tag, demo_sent)
                # print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))

'''
## my prediction script:
## self-built eval data: for input files with only two tab-separated columns (sentence, annotation)
# demo_data_path = "./linking/eval_pm"
# demo_data_predict_path = "./linking/eval_pm_pre_tf"
demo_data_path = "./linking/eval_data"
demo_data_predict_path = "./linking/eval_data_pre_tobedelete"
if args.mode == 'eval_data':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print("ckpt:\n", ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()

    fin = open(demo_data_path)
    fout = open(demo_data_predict_path, mode="w", encoding="utf8")

    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        for line in fin:
            sent = line.strip().split("\t")[0]  # sentence
            sequence = line.strip().split("\t")[1]  # gold annotation

            demo_sent = list(sent.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            tag = model.demo_one(sess, demo_data)
            # entities found by the model
            PRO = get_PRO_entity(tag, demo_sent)
            print("\nsent:", sent)
            print("PRO:", PRO)
            LINK = linking.ttt(PRO)
            print("LINK:", LINK)
            

            line = "{}\t{}\t{}\t{}\n".format(sent, sequence, PRO, LINK)
            fout.write(line)
            
    fin.close()
    fout.close()
'''

'''
## my prediction script:
## for the to-be-predicted text supplied by the PM
demo_data_path = "./linking/eval_pm"
demo_data_predict_path = "./linking/eval_pm_pre_tf"
if args.mode == 'eval_data':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print("ckpt:\n", ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()

    with open(demo_data_path) as f:
        lines = f.readlines()[1:]
    fout = open(demo_data_predict_path, mode="w", encoding="utf8")
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        for line in lines:
            temp = line.rstrip().split("\t")
            pro_content = temp[0].split("[SEP]")  # project-related fields
            # print(pro_content)
            # pro_name = pro_content[0]  # project name
            # pro_res = pro_content[1]  # project responsibilities
            # pro_des = pro_content[2]  # project description
            sentence = " ".join(pro_content)
            tag = temp[1]  # gold annotation

            # entities found by the model
            # PROs = []
            # for i in [pro_name, pro_res, pro_des]:
            #     demo_sent = list(i.strip())
            #     demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            #     demo_tag = model.demo_one(sess, demo_data)
            #     PRO = get_PRO_entity(demo_tag, demo_sent)
            #     PROs += PRO
            demo_sent = list(sentence.strip())
            demo_data = [(demo_sent, ['O'] * len(demo_sent))]
            demo_tag = model.demo_one(sess, demo_data)
            PRO = get_PRO_entity(demo_tag, demo_sent)
            print("\nsent:", temp[0])
            print("PROs:", PRO)
            LINK = linking.ttt(PRO)
            print("LINK:", LINK)
            

            line = "{}\t{}\t{}\t{}\n".format(temp[0], tag, PRO, LINK)
            fout.write(line)
            
    fout.close()
'''

'''
## my prediction script:
## for the to-be-predicted text supplied by the PM
demo_data_path = "./linking/eval_pm"
demo_data_predict_path = "./linking/eval_pm_pre_tf"
if args.mode == 'eval_data':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print("ckpt:\n", ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()

    with open(demo_data_path) as f:
        lines = f.readlines()[1:]
    fout = open(demo_data_predict_path, mode="w", encoding="utf8")
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        for line in lines:
            temp = line.rstrip().split("\t")
            pro_content = temp[0].split("[SEP]")  # project-related fields
            # print(pro_content)
            # pro_name = pro_content[0]  # project name
            # pro_res = pro_content[1]  # project responsibilities
            pro_des = pro_content[2]  # project description
            PRO = []  # entities found by the model
            sentence = pro_des.strip()  # project description text
            if sentence:
                # print("sentence:", sentence)
                # sentence = " ".join(pro_content)
                # tag = temp[1]  # gold annotation

                # entities found by the model
                # PROs = []
                # for i in [pro_name, pro_res, pro_des]:
                #     demo_sent = list(i.strip())
                #     demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                #     demo_tag = model.demo_one(sess, demo_data)
                #     PRO = get_PRO_entity(demo_tag, demo_sent)
                #     PROs += PRO
                demo_sent = list(sentence.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                demo_tag = model.demo_one(sess, demo_data)
                PRO = get_PRO_entity(demo_tag, demo_sent)
                print("\nsent:", sentence)
                print("PRO:", PRO)
                # LINK = linking.ttt(PRO)
                # print("LINK:", LINK)
            line = "{}\n".format(PRO)
            fout.write(line)
            
    fout.close()
'''


## my prediction script:
demo_data_path = "./linking/eval_data"
demo_data_predict_path = "./linking/eval_data_new"
if args.mode == 'eval_data':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print("ckpt:\n", ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()

    with open(demo_data_path, encoding="utf8") as f:
        lines = f.readlines()
    fout = open(demo_data_predict_path, mode="w", encoding="utf8")
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        for line in lines:
            PRO = []  # entities found by the model
            sentence = line.rstrip()  # the sentence to tag
            if sentence:
                # print("sentence:", sentence)
                # sentence = " ".join(pro_content)
                # tag = temp[1]  # gold annotation

                # entities found by the model
                # PROs = []
                # for i in [pro_name, pro_res, pro_des]:
                #     demo_sent = list(i.strip())
                #     demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                #     demo_tag = model.demo_one(sess, demo_data)
                #     PRO = get_PRO_entity(demo_tag, demo_sent)
                #     PROs += PRO
                demo_sent = list(sentence.strip())
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                demo_tag = model.demo_one(sess, demo_data)
                PRO = get_PRO_entity(demo_tag, demo_sent)
                print("\nsent:", sentence)
                print("PRO:", PRO)
                # LINK = linking.ttt(PRO)
                # print("LINK:", LINK)
            line = "{}\n".format(PRO)
            fout.write(line)
            
    fout.close()