torch.jit (Just-In-Time 即时编译器) (动态图转为静态图)(将模型转成TorchScript)

         JIT 是一种概念,全称是 Just In Time Compilation,中文译为「即时编译」,是一种程序优化的方法。如lambda就是靠JIT来对python进行优化

        用 JIT 将 Python 模型转换为 TorchScript Module

       JIT是Python的概念,并不是Pytorch特有。torch.jit是PyTorch的即时编译器(Just-In-Time Compiler,JIT)模块。这个模块存在的意义是把PyTorch的动态图转换成可以优化和序列化的静态图,其主要工作原理是通过输入预先定义好的张量,追踪整个动态图的构建过程,得到最终构建出来的动态图,然后转换为静态图(通过中间表示,即IntermediateRepresentation,来描述最后得到的图)。通过JIT得到的静态图可以被保存,并且被PyTorch其他的前端(如C++语言的前端)支持。另外,JIT也可以用来生成其他格式的神经网络描述文件,如ONNX。需要注意的一点是,torch.jit支持两种模式,即脚本模式(ScriptModule)和追踪模式(Tracing)。前者和后者都能构建静态图,区别在于前者支持控制流,后者不支持,但是前者支持的神经网络模块比后者少,比如脚本模式不支持torch.nn.GRU(详细的描述可以参考PyTorch官方提供的JIT相关的文档)。

示例,模型量化后保存为jit格式

import torch
from torch import nn
from torch import quantization

class ConvBnReluModel(nn.Module):
    def __init__(self) -> None:
        super(ConvBnReluModel, self).__init__()
        self.conv = nn.Conv2d(3,5,3, bias=False)
        self.bn = nn.BatchNorm2d(5)
        self.relu = nn.ReLU(inplace=True)
    
    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

m = ConvBnReluModel()
m.eval()
layers = [['conv','bn','relu']]
f = quantization.fuse_modules(m,layers, inplace=True)

types_to_quantize = {nn.Conv2d, nn.BatchNorm2d, nn.ReLU}
q = quantization.quantize_dynamic(f, types_to_quantize, dtype=torch.qint8)

s = torch.jit.script(q)
torch.jit.save(s, 'quantize_model.pth')
    

  • 4
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值