蓝桥杯人工智能赛-模型转换与量化

本篇对简单的进行ONNX模型转换以及TFLite模型量化做出总结。

#task-start
import numpy as np
import onnxruntime as ort
import torch
import torch.nn as nn


class TextClassifier(nn.Module):
    def __init__(self, vocab_size=1000, embed_dim=128, hidden_dim=512, num_classes=2):
        super(TextClassifier, self).__init__()

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn = nn.LSTM(embed_dim, hidden_dim)
        self.fc = nn.Linear(hidden_dim, num_classes)

    def forward(self, text):

        embedded = self.embedding(text)
        packed_output, (hidden, cell) = self.rnn(embedded)
        output = self.fc(hidden.squeeze(0))
        return output


def convert():
    # TODO
    classifier = TextClassifier()
    classifier.load_state_dict(torch.load('model.pt'))   # 加载并导入模型
    sample_input = torch.ones(size=(128, 1), dtype=torch.long)  # 给出示例输入
    torch.onnx.export(
        model=classifier,
        f='text_classifier.onnx',
        args=sample_input,    # 导出的三个重要参数:model(模型), args(示例输入), f(保存路径)
        input_names=['input'],  # 由于一些模型有多输入,因此需要传递字典告诉模型输入名,然而本题仅为单输入
        dynamic_axes={'input': {0 :'batch_size'}}  # 这里至关重要,告诉模型哪里一维可以动态变化(即batch_size),
        #输入格式为嵌套字典,key告诉模型动态变化的输入名, value为字典形式,value_key为动态维度(int), value_value是动态维度名称
    )


def inference(model_path, input):
    # TODO
    session = ort.InferenceSession(model_path, providers=['CPUExecutionProvider'])
    # onnxruntime 加载推理会话,需要指定路径以及providers,本题使用CPU
    input = torch.tensor(input, dtype=torch.long).unsqueeze(0).reshape(-1, 1).detach().numpy()
    # 将input转换为numpy形式
    result = session.run(None, {'input': input})[0].tolist()
    # run的时候需要输入:输出名(这里为None,因为没用到),输入(字典形式,导出时使用的输入名为key,输入array为value)
    # run后得到一个列表[array([.....])],因此处理输出时用了一个索引以及tolist()
    return result


def main():
    convert()
    result = inference('/home/project/text_classifier.onnx', [101, 304, 993, 108,102])
    print(result)


if __name__ == '__main__':
    main()
#task-end

其次是模型的量化,将tensorflow模型量化为tflite格式。

# quantize-start
import tensorflow as tf
import os
import numpy as np
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
import json


def quantize_model(model_path, quantized_model_path):
    # TODO
    model = tf.keras.models.load_model(model_path)  # 老样子,要导出,就要导入一个模型,因此加载模型为第一步 
    converter = tf.lite.TFLiteConverter.from_keras_model(model)  # 加载一个转换器
    converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
    converter.target_spec.supported_types = [tf.float32]  # 设置支持的tensor格式
    quantized_model = converter.convert()   # 开始转换
    open(quantized_model_path, 'wb').write(quantized_model)   # 将量化的模型写入


def prediction_label(test_sentence, model_path):
    # TODO
    test_sentence = [test_sentence]   # 输入为字符串,然而texts_to_sequences的输入为一个列表
    vectorize = Tokenizer()           # 定义一个Tokenizer
    with open('word_index.json', 'r') as f:
        word_index = json.load(f)
    vectorize.word_index = word_index  # 导入word_index,并将文本转换为sequence
    sequence = vectorize.texts_to_sequences(test_sentence)
    pad_sequence = pad_sequences(sequence, maxlen=100)  # 对文本padding以适应输入
    interpreter = tf.lite.Interpreter(model_path=model_path)   # 加载解释器
    interpreter.allocate_tensors()                             # 分配张量
    output = interpreter.get_output_details()[0]               # tflite与onnx不同,解释器中用列表管理输入输出的细节
    input = interpreter.get_input_details()[0]                 
    interpreter.set_tensor(input['index'], np.array(pad_sequence, dtype=np.float32))  # 设置张量,注意,input本质上是一个字典,他的格式为:
    """details = {
        'name': tensor_name,
        'index': tensor_index,
        'shape': tensor_size,
        'dtype': tensor_type,
        'quantization': tensor_quantization,
    }"""
    interpreter.invoke()   # 进行推理测试
    output_value = interpreter.get_tensor(output['index'])[0][0]  # 完成推理测试后,从解释器中得到输出张量,返回numpy格式
    label = 1 if output_value > 0.5 else 0  # 由于是二分类,最后会经过sigmoid
    return label


def main():
    # 量化模型
    quantize_model('/home/project/model.h5', '/home/project/quantized_model.tflite')
    # 测试示例
    test_sentence = "一个 公益广告 : 爸爸 得 了 老年痴呆  儿子 带 他 去 吃饭  盘子 里面 剩下 两个 饺子  爸爸 直接 用手 抓起 饺子 放进 了 口袋  儿子 愣住 了  爸爸 说  我 儿子 最爱 吃 这个 了  最后 广告 出字 : 他 忘记 了 一切  但 从未 忘记 爱 你    「 转 」"
    print(prediction_label(test_sentence, '/home/project/quantized_model.tflite'))


if __name__ == "__main__":
    main()
# quantize-end

本文仅对如何简单的进行转换以及量化给出简要的流程示例,代码在人工智能赛中都可以AC。模型的转换部署中有很多的学问,期待后续能深入学习其中的知识。 

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

藤宫博野

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值