模型量化(二)—— 训练后量化PTQ(全代码)

13 篇文章 1 订阅
8 篇文章 27 订阅
文章详细介绍了训练后量化(PTQ)技术,包括其在模型部署中的优势,如何处理激活值的量化,以及提供了一个使用PyTorch实现的简单神经网络模型的量化过程,展示了从预训练模型到量化模型的完整流程,包括观察者机制、校准和量化步骤以及量化后模型的大小和精度变化。
摘要由CSDN通过智能技术生成

训练后量化(Post-training Quantization,PTQ)是一种常见的模型量化技术,它在模型训练完成之后应用,旨在减少模型的大小和提高推理速度,同时尽量保持模型的性能。训练后量化对于部署到资源受限的设备上,如移动设备和嵌入式设备,特别有用。

在我们量化时,量化操作可以应用于模型的输入、权重 和 激活(即神经元输出值)上。

但我们发现,对于激活值,我们执行反量化时,并不知道这些激活值对应的浮点数矩阵的最大值和最小值,即我们执行非对称或对称量化里面的 𝛼, β 参数,所以我们拿到一个模型时,最多只能对它的权重W和输入X做量化,对于激活值Y的反量化,我们需要一组小的calibration set数据来初步计算对于Y的S和Z参数。

不熟悉非对称或对称量化的朋友可以康康这篇:《模型量化(一)—— 非对称量化、对称量化(全代码)》

在这里插入图片描述
 

 

PTQ流程:

在这里插入图片描述
Observer,顾名思义就是模型在正常inference的时候会被记录下正常的浮点激活值,用来算激活值对应的S和Z参数。

Calibrate后模型的W和Y都有对应的S和Z了,模型名义上量化完成。浮点的输入X也能off-line地实时算它对应的S和Z。

所以量化后的模型运行时,先对浮点输入进行量化,然后与整型的W矩阵相乘,得到整型的激活值,这时再反量化为浮点激活值,对应于下一个神经元的浮点输入,依次循环。
大家可能会想吗,这么麻烦,又是量化又是反量化,怎么还会压缩模型和加速模型呢?

压缩模型:原本所有的W都是浮点数存储,比如float32,现在转换为int8存储,模型尺寸减了大概4倍;再额外存一些神经元或网络层的S和Z参数(取决于量化的粗粒度),相对于W来说占内存很小(如果是很细粒度的量化可能这部分也得好好考虑,量化的粒度分为权重级量化、层级量化、通道级量化等)。

加速模型:主要的收益是使得模型中占大头的 W * X 操作变成了整型相乘,功耗和时延最低(浮点数相乘时功耗和时延最大)。3 * 100 * 100 * 10的全连接网络中,有213个神经元,但是有 3 * 100 * 100 * 10 = 300M个参数!这还是忽略了bias。量化相当于就是让这 300M 次乘法更轻量。而相对的 overhead 就是对开头的3 + 100 + 100 = 203个中间输入进行一下量化 和 对 100 + 100 + 10 = 210个激活值进行一下反量化,这部分开销随着网络层数与参数的增加几乎可以忽略不计。
一些专门的深度学习加速器和现代CPU/GPU提供了对低位宽整数(如int8)的优化支持,用这些硬件后可以更加体现模型量化的优势。

量化会带来一定的量化误差,即模型精度会受影响,这肯定的,但按经验来说几乎没什么影响,不要压到int4或int2这么极限就行。

 

全代码

预训练模型

import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

# Make torch deterministic
_ = torch.manual_seed(0)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"


# Define the model
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x

net = VerySimpleNet().to(device)


# Train the model
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    net.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=1)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)


# Define the testing loop
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0
    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')


# Print weights and size of the model before quantization

# Print the weights matrix of the model before quantization
print('Weights before quantization')
print(net.linear1.weight)
print(net.linear1.weight.dtype)

print('Size of the model before quantization')
print_size_of_model(net)

print(f'Accuracy of the model before quantization: ')
test(net)

加入Observer

# Insert min-max observers in the model

class QuantizedVerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(QuantizedVerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net_quantized = QuantizedVerySimpleNet().to(device)
# Copy weights from unquantized model
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

net_quantized.qconfig = torch.ao.quantization.default_qconfig
net_quantized = torch.ao.quantization.prepare(net_quantized) # Insert observers
net_quantized

校准模型

#用测试集再跑一次装了observer的模型
test(net_quantized)

print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述
这时看到激活层的 𝛼, β 都有了,good!

量化模型

# Quantize the model using the statistics collected

net_quantized = torch.ao.quantization.convert(net_quantized)

print(f'Check statistics of the various layers')
net_quantized

在这里插入图片描述

# Print the weights matrix of the model after quantization
print('Weights after quantization')
print(torch.int_repr(net_quantized.linear1.weight()))


# Compare the dequantized weights and the original weights
print('Original weights: ')
print(net.linear1.weight)
print('')
print(f'Dequantized weights: ')
print(torch.dequantize(net_quantized.linear1.weight()))
print('')

# Print size and accuracy of the quantized model
print('Size of the model after quantization')
print_size_of_model(net_quantized)
print('Testing the model after quantization')
test(net_quantized)
  • 19
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 6
    评论
PyTorch支持通过量化技术来压缩模型,减小模型大小和内存占用,并提高模型的推理性能。其中,PTQ(Post Training Quantization)是一种常见的量化方法,它可以在训练后对模型进行量化PTQ的基本思路是将原始模型中的浮点数参数转化为固定位宽的整数,从而减小模型的大小和内存占用,提高模型在嵌入式设备上的推理速度。在PTQ中,可以对权重、激活值、梯度等进行量化。 下面是使用PyTorch进行PTQ的基本流程: 1. 定义模型 首先需要定义一个PyTorch模型。 2. 定义量化方法 接下来需要定义量化方法。PyTorch提供了一些量化方法,可以根据实际需求进行选择。例如,可以使用torch.quantization.quantize_dynamic()方法进行动态量化,或者使用torch.quantization.quantize_static()方法进行静态量化。 3. 对模型进行量化 使用定义的量化方法对模型进行量化,将浮点数参数转化为整数参数。可以使用torch.quantization.prepare()方法对模型进行准备,使用torch.quantization.convert()方法进行转换。 4. 测试量化后的模型 量化完成后,需要测试量化后的模型,确保准确性没有明显下降。 下面是一个简单的示例代码,演示了如何使用PyTorch进行PTQ: ```python import torch import torch.nn as nn import torch.optim as optim import torchvision import torchvision.transforms as transforms from torchvision.models import resnet18 from torch.utils.data import DataLoader # 定义模型 model = resnet18() # 定义数据预处理 transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]) # 加载数据集 trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform) trainloader = DataLoader(trainset, batch_size=128, shuffle=True) # 定义损失函数和优化器 criterion = nn.CrossEntropyLoss() optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9) # 训练模型 for epoch in range(5): running_loss = 0.0 for i, data in enumerate(trainloader, 0): inputs, labels = data optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, labels) loss.backward() optimizer.step() running_loss += loss.item() print('[Epoch %d] loss: %.3f' % (epoch + 1, running_loss / len(trainloader))) # 定义量化方法 quantization_method = torch.quantization.quantize_dynamic # 对模型进行量化 model.qconfig = torch.quantization.get_default_qconfig('fbgemm') quantized_model = quantization_method(model, qconfig_spec={nn.Linear}, dtype=torch.qint8) # 测试量化后的模型 quantized_model.eval() testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform) testloader = DataLoader(testset, batch_size=128, shuffle=False) correct = 0 total = 0 with torch.no_grad(): for data in testloader: images, labels = data outputs = quantized_model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print('Accuracy of the network on the test images: %d %%' % (100 * correct / total)) ``` 注意:PTQ可能会对模型的准确性产生一定的影响,因此需要根据实际情况进行调整。同时,PTQ的效果也受到数据集的影响,因此需要在实际应用中进行测试和优化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

全栈O-Jay

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值