君正magik将自己的模型量化,如何量化并export为onnx

由于该量化包是基于linux系统上,针对yolo8进行量化,以下是我在量化resnet18的步骤和中所遇到的问题。

1. 首先训练好模型,导出为model.pt。在yolo.py的export方法中,源码为

        from magik.magik_quantize import get_magik_quantizer
        from ingenic_magik_trainingkit.QuantizationTrainingPlugin.python.fqat.quantization.utils import magik_export_onnx

        net = DetectionModel(args.model_cfg, nc=args.nc)
        ckpt = torch.load(args.model_weights, map_location='cpu')
        net.load_state_dict((ckpt['ema'] if ckpt.get('ema') else ckpt['model']).float().state_dict())
        net.cuda().eval()
        net = get_magik_quantizer(args.bit).quant(net)

        if 'ema' in ckpt:
            ckpt['ema'] = net
        else:
            ckpt['model'] = net
        dummy_input = torch.randn(1, 3, 640, 640).cuda()
        magik_export_onnx(net, dummy_input, args.output_file)

是由于yolo需要从yaml文件中加载权重,而由于我们已经有了一个训练好的pt模型,所以代码只需要改为:

from magik.magik_quantize import get_magik_quantizer
        from ingenic_magik_trainingkit.QuantizationTrainingPlugin.python.fqat.quantization.utils import magik_export_onnx
        net = resnet.resnet18(num_classes=args.nc)
        ckpt = torch.load(args.model_weights, map_location="cpu")
        net.load_state_dict(ckpt.float().state_dict())
        net.cuda().eval()
        net = get_magik_quantizer(args.bit).quant(net)

        dummy_input = torch.randn(1, 1, 640, 640).cuda()
        magik_export_onnx(net, dummy_input, args.output_file)

        1. 初始化网络为resnet

        2. load模型,并将模型以state_dict的方式加载

2.  由于在导入模型后,量化的核心代码在

        net = get_magik_quantizer(args.bit).quant(net)
    magik_quantizer = MagikQuantizer(quantize_parameter, cal_data_reader)

首先会进行一个MagikQuantizer的初始化工作,其中quantizer_parameter无需修改,而cal_data_Reader需要给予自己的数据集进行适当修改。若仍然是img的图片数据集,则无需修改,但由于我是对音频进行分类,需要先将音频数据转换为梅尔频谱图,因此仿照Yolov5DataReader自定义一个AudioDataReader继承Yolov3DataReader

class AudioDataReader(Yolov3DataReader):

    def __init__(self, path='', nbatchs=4, batch_size=4, img_size=(640, 640), mean=0, var=255, to_rgb=True):
        assert path != ''
        self.nbatchs = nbatchs
        self.batch_size = batch_size
        self.mean = mean
        self.var = var
        self.to_rgb = to_rgb
        self.spec_len = img_size[0]
        self.data_set = self.load_bin(path, img_size)

    @torch.no_grad()
    def load_audio(self, audio_file, cache=False):
        """
        加载并预处理音频
        :param audio_file:
        :param cache: librosa.load加载音频数据特别慢,建议使用进行缓存进行加速
        :return:
        """
        # 读取音频数据
        cache_path = audio_file + ".pk"
        # t = librosa.get_duration(filename=audio_file)
        if cache and os.path.exists(cache_path):
            tmp = open(cache_path, 'rb')
            wav, sr = pickle.load(tmp)
        else:
            wav, sr = librosa.load(audio_file, sr=16000)
            if cache:
                f = open(cache_path, 'wb')
                pickle.dump([wav, sr], f)
                f.close()

        # Compute a mel-scaled spectrogram: 梅尔频谱图
        spec_image = librosa.feature.melspectrogram(y=wav, sr=sr, hop_length=256)
        return spec_image

    @torch.no_grad()
    def load_bin(self, path, image_size):
        data_list = []
        img_data = open(path, 'r')
        img_list = img_data.read().splitlines()
        batch = []
        for i, img_path in enumerate(img_list):
            img_m = self.load_audio(img_path)
            img = self.pre_process(img_m)

            assert img is not None, f'Image Not Found {path}'
            batch.append(img)
            if ((i+1)%self.batch_size==0):
                data_list.append(torch.cat(batch))
                batch=[]
            if (i>self.nbatchs*self.batch_size):
                break
        return data_list

    @torch.no_grad()
    def pre_process(self, spec_image):
        """音频数据预处理"""
        if spec_image.shape[1] > self.spec_len:
            input = spec_image[:, 0:self.spec_len]
        else:
            input = np.zeros(shape=(self.spec_len, self.spec_len), dtype=np.float32)
            input[:, 0:spec_image.shape[1]] = spec_image
        input = self.normalization(input)
        input = input[np.newaxis, np.newaxis, :]
        input_tensors = np.concatenate([input])
        input_tensors = torch.tensor(input_tensors, dtype=torch.float32)
        return input_tensors

    def normalization(self, spec_image, ymin=0.0, ymax=1.0):
        """
        数据归一化
        """
        spec_image = spec_image.astype(np.float32)
        spec_image = (spec_image - spec_image.min()) / (spec_image.max() - spec_image.min())
        spec_image = spec_image * (ymax - ymin) + ymin
        return spec_image

3. 由于在模型量化过程在,需要通过symbolic tracing将模型转化为图,代码说明可以查看http://pytorch.org/docs/master/fx.html#torch.fx.Tracer

通过符号追踪,对net中每个forward操作进行追踪,然后在graph中记录为node。大致代码如下:

        tracer = Tracer(self.skipped_module_names, self.skipped_module_classes)
        self.graph = tracer.trace(self.model, self._model_call)
        .......
     

但在debug过程中,我发现graph在trace完后为空,但在输出中明明有网络各层结构,仔细排查,并且输出的INFO只有remove各个节点,并没有insert,后发现在Tracer.trace的最后一行有一个

self.cut_graph()
    def cut_graph(self):
        input_nodes = set()
        output_nodes = set()
        modules = dict(self.root.named_modules())
        for node in self.graph.nodes:
            if node.op == 'call_module':
                target = node.target
                module = modules[target]
                if isinstance(module, InputNOp):
                    input_nodes.add(node)
                if isinstance(module, OutputNOp):
                    output_nodes.add(node)


        trace_nodes = set()
        def get_trace_nodes(node):
            if node in trace_nodes:
                return
            trace_nodes.add(node)
            if node in input_nodes:
                return
            for input_node in node.all_input_nodes:
                get_trace_nodes(input_node)
        for output_node in output_nodes:
            get_trace_nodes(output_node)
        self.graph.lint()
        for in_node in input_nodes:
            in_node.args = tuple()
        for node in reversed(self.graph.nodes):
            if node not in trace_nodes:
                logger.info("remove preprocess/postprocess node: %s" % node.target)
                self.graph.erase_node(node)

其中的

                if isinstance(module, InputNOp):
                    input_nodes.add(node)
                if isinstance(module, OutputNOp):
                    output_nodes.add(node)

在结合resnet和yolo的模型区别可以发现,在yolo模型中卷积块都有InputNOP/OutputNOP,如

  (model): Sequential(
    (0): Conv(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (bn): BatchNorm2d(16, eps=0.001, momentum=0.03, affine=True, track_running_stats=True)
      (act): LeakyReLU(negative_slope=0.01, inplace=True)
      (in_nop): InputNOp()
    )

....

      (dfl): DFL(
        (conv): Conv2d(16, 1, kernel_size=(1, 1), stride=(1, 1), bias=False)
      )
      (out_nop): OutputNOp()
    )

因此在原本的resnet中相应的input和output模块插入InputNOP/OutputNOP。

  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值