复现resnet18花分类,并计算torch/onnx/mnn fp32和int8精度

【pytorch花分类】使用torchvision的resnet18

提前下载好数据集并且分割好训练和验证集

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import datasets, models, transforms

# 定义数据的路径和预处理方法
data_dir = '/home/ruoji/MNN/data/flowers'
train_dir = data_dir + '/train'
val_dir = data_dir + '/val'

data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# 加载数据
train_data = datasets.ImageFolder(train_dir, data_transforms['train'])
val_data = datasets.ImageFolder(val_dir, data_transforms['val'])
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=4)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=4)

# 加载模型
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)  # 将全连接层改为5分类
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

num_epochs = 10
scheduler = lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

best_acc = 0.0
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item() * inputs.size(0)
    scheduler.step()

    model.eval()
    num_correct = 0
    num_total = 0
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs)
        _, predicted = torch.max(outputs, 1)
        num_correct += (predicted == labels).sum().item()
        num_total += labels.size(0)

    epoch_loss = running_loss / len(train_data)
    epoch_acc = num_correct / num_total
    print('Epoch {} - Loss: {:.4f} Acc: {:.4f}'.format(epoch+1, epoch_loss, epoch_acc))
    if epoch_acc > best_acc:
        best_acc = epoch_acc
        PATH = '/home/ruoji/MNN/data/resnet18_flowers_best.pth'
        torch.save(model.state_dict(), PATH)

print('Finished Training')

导出onnx模型

import torch
import torchvision

# 加载训练好的模型
model = torchvision.models.resnet18(pretrained=False)
model.fc = torch.nn.Linear(512, 5)
model.load_state_dict(torch.load('/home/ruoji/MNN/data/resnet18_flowers.pth'))

# 设置输入张量的形状
batch_size = 1
input = torch.randn(batch_size, 3, 224, 224)

input_name = 'input'
output_name = 'output'
# 将模型转换为ONNX格式
torch.onnx.export(model, input, 
                  'resnet18_flower_onnx.onnx', 
                  input_names = [input_name],
                  output_names = [output_name],
                  verbose=True,
                  opset_version=11,
                  dynamic_axes={input_name: {0: 'batch_size'},
                                output_name: {0: 'batch_size'}})

onnx转mnn int量化

mnnconvert -f ONNX --bizCode MNN --modelFile resnet18_flower_onnx.onnx --MNNModel resnet18_flower_fp32.mnn --keepInputFormat


mnnquant resnet18_flower_fp32.mnn resnet18_flower_int8.mnn quant_flower.json 
{
    "format":"RGB",
    "mean":[
        103.94,
        116.78,
        123.68
    ],
    "normal":[
        0.017,
        0.017,
        0.017
    ],
    "width":224,
    "height":224,
    "path":"/home/ruoji/MNN/data/flowers/val/daisy",
    "used_image_num":50,
    "feature_quantize_method":"KL",
    "weight_quantize_method":"MAX_ABS"
}

测试精度

pth: 0.9478
在这里插入图片描述

onnx : 0.9341
在这里插入图片描述

mnn_fp32: 0.9341
在这里插入图片描述
mnn_int8: 0.9341 离线量化
在这里插入图片描述

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值