#!/usr/bin/env python
# author: qibaoyuan - qibaoyuan@xiaomi.com 2019.08.01
# Just a demo, no guarantees.
"""Run an exported ERNIE classification inference model on one sentence.

Loads a PaddlePaddle Fluid inference model from MODEL_DIR, converts a single
text example into token ids with ERNIE's ``ClassifyReader``, runs the model
NUM_RUNS times on that one input, then prints the raw outputs, their argmax
and the time spent in the inference loop.
"""
import collections
import os
import time

import numpy as np

# Python's file APIs do not expand "~", so these are passed through
# os.path.expanduser before use.
MODEL_DIR = "~/ERNIE-output-2019-07-29/step_38212_inference_model/step_38212_inference_model"
VOCAB_PATH = "~/vocab.txt"

# Sequence length the inference model is fed with: every input tensor has
# shape (1, SEQ_LEN, 1). Must match the length the model was exported for.
SEQ_LEN = 512
# Number of times the same input is run through the model (for timing).
NUM_RUNS = 50

TEXT = "为了预防心脑血管病,老年人需要吃阿司匹林吗? 吃不吃阿司匹林是不是老年人没有直接的关系。"

Example = collections.namedtuple('Example', ['text_a', 'text_b', 'label'])


def build_feed(feed_names, token_ids, seq_len=SEQ_LEN, pad_id=0):
    """Build the feed dict for a single example of up to ``seq_len`` tokens.

    The four inputs are keyed by ``feed_names[0..3]`` in this order: token
    ids, segment (text type) ids, position ids, input mask. That order is
    assumed from how the model is used here; confirm it against the
    ``feed_target_names`` returned by ``load_inference_model``.

    Args:
        feed_names: sequence of at least four feed variable names.
        token_ids: iterable of int token ids for the example.
        seq_len: length every input is padded to.
        pad_id: id used to fill positions after the real tokens.

    Returns:
        dict mapping each feed name to a ``(1, seq_len, 1)`` array. Ids are
        int64; the mask is float32, 1.0 on real tokens and 0.0 on padding.

    Raises:
        ValueError: if there are more than ``seq_len`` token ids.
    """
    ids = list(token_ids)
    n_real = len(ids)
    if n_real > seq_len:
        raise ValueError(
            "got {} token ids, more than seq_len={}".format(n_real, seq_len))

    shape = (1, seq_len, 1)
    padded_ids = np.full(seq_len, pad_id, dtype=np.int64)
    padded_ids[:n_real] = ids
    # Only real tokens are attended to; padding is masked out.
    input_mask = np.zeros(seq_len, dtype=np.float32)
    input_mask[:n_real] = 1.0

    return {
        feed_names[0]: padded_ids.reshape(shape),
        feed_names[1]: np.zeros(shape, dtype=np.int64),
        feed_names[2]: np.arange(seq_len, dtype=np.int64).reshape(shape),
        feed_names[3]: input_mask.reshape(shape),
    }


def main():
    """Load the model, classify TEXT NUM_RUNS times and print the results."""
    # Imported here rather than at module top so build_feed can be imported
    # (and tested) without PaddlePaddle installed.
    import paddle.fluid as fluid
    # See ERNIE's reader package (task_reader) for reference.
    from paddle.reader.task_reader import ClassifyReader

    exe = fluid.Executor(fluid.CPUPlace())
    inference_program, feed_target_names, fetch_targets = (
        fluid.io.load_inference_model(
            dirname=os.path.expanduser(MODEL_DIR), executor=exe))

    reader = ClassifyReader(vocab_path=os.path.expanduser(VOCAB_PATH),
                            is_classify=False, is_inference=True)
    example = Example(text_a=TEXT, text_b='', label='')
    record = reader._convert_example_to_record(
        example, reader.max_seq_len, reader.tokenizer)
    print(record)

    # NOTE(review): falls back to 0 if the reader exposes no pad_id;
    # confirm 0 is the [PAD] id of the vocab in use.
    feed_data = build_feed(feed_target_names, record.token_ids,
                           seq_len=SEQ_LEN,
                           pad_id=getattr(reader, "pad_id", 0))

    # The input never changes, so the feed is built once and only the
    # inference calls are timed.
    start = time.time()
    results = None
    for _ in range(NUM_RUNS):
        results = exe.run(inference_program, feed=feed_data,
                          fetch_list=fetch_targets)
    elapsed = time.time() - start

    print(results)
    print(np.argmax(results))
    print("{} runs: {:.3f} s total, {:.4f} s/run".format(
        NUM_RUNS, elapsed, elapsed / NUM_RUNS))


if __name__ == "__main__":
    main()
使用百度深度学习模型 ERNIE 对输入的字符串做分类预测（基于 PaddlePaddle Fluid）
最新推荐文章于 2024-03-01 11:18:08 发布