我们要进行第三个步骤的操作,就是使用test功能,
elif args.mode == 'test':
ckpt_file = tf.train.latest_checkpoint(model_path)
print(ckpt_file)
paths['model_path'] = ckpt_file
model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
model.build_graph()
print("test data: {}".format(test_size))
#测试
model.test(test_data)
和demo模式一样,首先使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型,然后也是创建了一个BiLSTM_CRF对象,并调用了它的build_graph()函数,这些咱们之前都讲过,调用这个函数就以为已经进行了模型的设置,只等传入feed_dict就可以得到预测结果了。
然后调用了model.test并传入测试数据。
def test(self, test):
saver = tf.train.Saver()
with tf.Session(config=self.config) as sess:
self.logger.info('=========== testing ===========')
saver.restore(sess, self.model_path)
label_list, seq_len_list = self.dev_one_epoch(sess, test)
self.evaluate(label_list, seq_len_list, test)
一句saver.restore引入已经训练好的模型,然后直接使用dev_one_epoce函数和evaluate函数,这两个函数之前我们训练的时候也用到过,主要目的是检测训练集的训练效果在测试集表现怎么样,这里的话省略了训练的步骤,而是调用以前训练好的模型,直接进行测试集的检测。
其实dev_one_epoce就是得到了一个有字,真实标签,预测标签组成的矩阵,但是过程是通过predict_one_batch来得到预测标签的,这样就可以用于evaluate的准确率检验了。
这个模式还是比较简单的,和train模式差不多,只是少了训练的步骤。