torch.fx模型量化进阶

一、定义

  1. 定义
  2. 案例-图-算子追踪
  3. 算子替换
  4. 新建算子并替换
  5. 子图替换
  6. fx 模型量化
  7. 接口解读

二、实现

  1. 定义
    1.1 torch.fx设计的目标就是在图上做各种变换,以完成图优化、量化等图功能性的改变。
    1.2 在不改变原模型的基础上修改模型。
    1.3 计算图重写
  2. 案例-图-算子追踪
import torch
from torch.fx import symbolic_trace

# Simple module for demonstration
class MyModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

module = MyModule()
# 符号追踪这个模块
# Symbolic tracing frontend - captures the semantics of the module
symbolic_traced = symbolic_trace(module)
# 中间表示
# High-level intermediate representation (IR) - Graph representation
print(symbolic_traced.graph)    #追踪模型
  1. 算子替换
# Simple module for demonstration
class MyModule(torch.nn.Module):
    def forward(self, x, y):
        return torch.add(x, y)

import torch
import torch.fx as fx

def transform(m: torch.nn.Module,
              tracer_class : type = fx.Tracer) -> torch.nn.Module:
    graph = tracer_class().trace(m)
    # FX represents its Graph as an ordered list of
    # nodes, so we can iterate through them.
    for node in graph.nodes:
        # Checks if we're calling a function (i.e:
        # torch.add)
        if node.op == 'call_function':
            # The target attribute is the function
            # that call_function calls.
            if node.target == torch.add:
                node.target = torch.mul

    graph.lint() # Does some checks to make sure the
                 # Graph is well-formed.
    return fx.GraphModule(m, graph)

m = MyModule()
newModel = transform(m)
print(newModel)
  1. 新建算子并替换
import torch
from torch.fx import symbolic_trace
import operator

# 自定义模型
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)
# fx符号跟踪获取计算图
traced = symbolic_trace(M())
print(traced.graph)

# 定义 计算操作 目标模式
patterns = set([operator.add, torch.add, "add"])
# 分别对应上述的+ torch.add x.add, 前两个为call_function,后一个为 call_method类型节点

def add_2_bitwise_and(gm):
    # 1. 遍历图中的所有节点
    for n in gm.graph.nodes:
        # 2. 判断 n.target,匹配目标模式
        if any(n.target == pattern for pattern in patterns):
            # 3. 设置新节点的插入点
            with gm.graph.inserting_after(n): # inserting_before
                # 4. 利用原节点的参数创建新的 call_function 节点,运算类型为 torch.bitwise_and
                new_node = gm.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
                # 5. 替换所有使用原节点的节点的输入为新节点
                n.replace_all_uses_with(new_node)
            # 6. 删除被替换的节点
            gm.graph.erase_node(n)
    # 7. fx计算图编译,更新生成的python代码,及其可执行函数
    gm.recompile()
# 修改计算图
add_2_bitwise_and(traced)
print(traced.graph)

print(traced.code) # 打印更新后的计算图反向编译产生的python源代码
  1. 子图替换
import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x):
        val = torch.neg(x) + torch.relu(x)
        return torch.add(val, val)
# 使用fx符号追踪获取模型结构
traced = symbolic_trace(M())
# 查看原计算图
print(traced.graph)

# 打印反向翻译的python代码
print(traced.code)

# 定义 待匹配的图 原结构模式
def pattern(x):
    return torch.neg(x) + torch.relu(x)

# 定义 目标结构
def replacement(x):
    return torch.neg(torch.clamp(x,max=0))

# 使用 目标结构 replacement 替换 原结构模式 pattern
torch.fx.subgraph_rewriter.replace_pattern(traced, pattern, replacement)
# 打印更新后的计算图
print(traced.graph)

# 打印反向翻译的python代码
print(traced.code)


# 可以看出该子图替换效果为 将neg(x) + relu(x) 替换为 neg(clamp(x,max=0))
# 真值对比函数
def comparison(x):
    val = torch.neg(torch.clamp(x, max=0))
    return torch.add(val, val)
# 比较测试
comparison_fn = symbolic_trace(comparison)
x = torch.rand(1, 3)
ref_output = comparison_fn(x)   # 对比函数执行结果
test_output = traced.forward(x) # 更新后的编译子图执行结果
print(torch.max(ref_output-test_output)) # 打印误差 tensor(0.)
  1. fx 模型量化
torch.backends.quantized.engine = 'fbgemm'  # 设置量化后端  cpu设置
qconfig_mapping = get_default_qconfig_mapping("fbgemm")  
# 校准量化精度
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs=torch.randn([1, 3, 224, 224]))  #准备工作
prepared_model.eval()
with torch.inference_mode():
    for inputs, labels in test_dataloader:
        prepared_model(inputs)            #量化校对

quantized_recover_model = convert_fx(prepared_model)  #转换
import os
import copy
import time

import torch
from torch import nn

import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet18

from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization import get_default_qconfig_mapping

from torch.utils.data import DataLoader

def test(model, test_dataloader, device):
    best_acc = 0
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    test_acc = 0
    with torch.no_grad():  # 设置禁止计算梯度
        for batch_idx, (inputs, targets) in enumerate(test_dataloader):  # 从DataLoader获取数据
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)  # 前向传播
            criterion = nn.CrossEntropyLoss()
            loss = criterion(outputs, targets)  # 计算交叉熵损失

            test_loss += loss.item()  # item() 获取标量的数值
            _, predicted = outputs.max(1)  # 返回第1个维度上的最大的(元素值,索引) predicted为每个样本预测的分类Id
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

        test_acc = correct / total
        print('[INFO] Test Accurancy: {:.3f}'.format(test_acc), '\n')


def print_size_of_model(model):
    torch.save(model.state_dict(), "tmp.pt")
    print(f"The model size:{os.path.getsize('tmp.pt') / 1e6}MB")

model = resnet18(pretrained=True)
# 修改模型
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()  # type: ignore
model.fc = nn.Linear(model.fc.in_features, 10)
# model.load_state_dict(
#     torch.load("C:\\Users\\tanfengfeng\\.cache\\torch\\hub\\checkpoints\\resnet18_cifar10.pth", map_location='cpu'))
# model.to(torch.device("cpu"))
model.eval()

# 设置mean和scale
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# 准备数据集
train_data = torchvision.datasets.CIFAR10(root='data', train=True, transform=torchvision.transforms.ToTensor(),download=True)
test_data = torchvision.datasets.CIFAR10(root='data', train=False, transform=transform_test, download=True)

print("训练集的长度:{}".format(len(train_data)))
print("测试集的长度:{}".format(len(test_data)))

# DataLoader加载数据集
train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

#模型量化
torch.backends.quantized.engine = 'fbgemm'  # 设置量化后端
qconfig_mapping = get_default_qconfig_mapping("fbgemm")
# 校准量化精度
model_to_quantize = copy.deepcopy(model)
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs=torch.randn([1, 3, 224, 224]))
prepared_model.eval()
with torch.inference_mode():
    for inputs, labels in test_dataloader:
        prepared_model(inputs)

quantized_recover_model = convert_fx(prepared_model)

# print(f"quantized model {quantized_recover_model.graph.print_tabular()}")
script_module = torch.jit.trace(quantized_recover_model, example_inputs=torch.randn([1, 3, 224, 224]))
torch.jit.save(script_module, "quant_model.pth")
# print_size_of_model(prepared_model)
# print_size_of_model(quantized_recover_model)


#测试FP32模型精度和耗时
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:
    test(model, test_dataloader, device='cpu')
print(prof.table())


#测试int8模型精度和耗时
quantized_recover_model = torch.jit.load("quant_model.pth")
with torch.autograd.profiler.profile(enabled=True, use_cuda=False, record_shapes=False, profile_memory=False) as prof:
    test(quantized_recover_model, test_dataloader, device='cpu')
print(prof.table())
  1. 接口解读
    fx 量化接口在这里插入图片描述
    参数配置
    在这里插入图片描述
  • 6
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值