ResNet 图像分类：模型训练、权重导出及 .pt 模型转换为 ONNX 格式
最近在处理边缘计算产品的模型部署时遇到了一些麻烦：由于产品不支持 YOLO 分类模型的部署，于是改用 ResNet 网络并用 PyTorch 重新训练，同时需要将训练好的权重导出，并转换为 ONNX 通用格式。参考了网上大佬的代码后做了整合，在此记录一下，方便日后使用。
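以下是数据集目录结构（非常简单）：将需要分类的图片按类别存放到对应的子目录中，子目录名即类别名，类别数量需与代码中的 n_classes 一致；train 与 val 的比例为 8:2。

```
my_datasets/
├── train/
│   ├── 类别A/
│   ├── 类别B/
│   └── ...
└── val/
    ├── 类别A/
    ├── 类别B/
    └── ...
```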
完整的模型训练代码如下，包括数据加载、模型训练、权重保存和格式转换：
import os
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import models, datasets, transforms
import torch.utils.data as tud
import numpy as np
from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler
from PIL import Image
import matplotlib.pyplot as plt
import warnings
# Silence all library warnings (e.g. torchvision deprecation notices) to keep the log readable.
warnings.filterwarnings("ignore")
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')
n_classes = 5  # number of target classes; must match the number of class sub-directories
preteain = False  # True: download/use ImageNet-pretrained weights (needs network); False: train from scratch
epoches = 20  # number of training epochs

# ImageNet normalization statistics (what the optional pretrained weights expect).
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# Training set: random crop + horizontal flip for data augmentation.
traindataset = datasets.ImageFolder(root='./my_datasets/train/', transform=transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]))
# Validation set must be preprocessed deterministically: random augmentation here
# made the reported accuracy (and hence best.pt selection) noisy between epochs.
# Resize((224, 224)) matches the preprocessing used for single-image inference.
testdataset = datasets.ImageFolder(root='./my_datasets/val/', transform=transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
]))
classes = testdataset.classes

# NOTE: newer torchvision releases deprecate `pretrained=` in favour of `weights=`.
model = models.resnet18(pretrained=preteain)
if preteain:
    # Fine-tune only the new classifier head: freeze the pretrained backbone.
    for param in model.parameters():
        param.requires_grad = False
# Replace the ImageNet classifier head. Read in_features from the existing layer
# instead of hard-coding 512 so other ResNet variants work unchanged.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=n_classes, bias=True)
model = model.to(device)
def train_model(model, train_loader, loss_fn, optimizer, epoch):
    """Train ``model`` for one epoch, print the epoch metrics and return them.

    Args:
        model: network to train; it is switched to training mode here.
        train_loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        optimizer: optimizer that updates ``model``'s parameters.
        epoch: zero-based epoch index (printed 1-based).

    Returns:
        ``(mean_loss, accuracy_percent)`` as Python floats, averaged per sample.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    # Move batches to wherever the model's weights live instead of a global device.
    dev = next(model.parameters()).device
    total_loss = 0.0
    total_corrects = 0
    total = 0
    for inputs, labels in train_loader:
        inputs = inputs.to(dev)
        labels = labels.to(dev)
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        preds = outputs.argmax(dim=1)
        # .item() keeps the running totals as plain Python numbers, not device tensors.
        total_corrects += preds.eq(labels).sum().item()
        # Assumes loss_fn uses 'mean' reduction (CrossEntropyLoss default): weight by
        # batch size so the epoch average is per sample, not per batch.
        total_loss += loss.item() * inputs.size(0)
        total += labels.size(0)
    if total == 0:
        raise ValueError("train_loader yielded no samples")
    total_loss = total_loss / total
    acc = 100.0 * total_corrects / total
    print("echo:%4d, 损失loss:%.5f, 准确率:%6.2f%%" % (epoch + 1, total_loss, acc))
    return total_loss, acc
def test_model(model, test_loader, loss_fn, optimizer, epoch,
               last_path='last.pt', best_path='best.pt'):
    """Evaluate ``model`` on ``test_loader``, checkpoint it, print and return metrics.

    The current weights are always saved to ``last_path``. They are also saved
    to ``best_path`` when this call's accuracy beats every earlier call's; the
    best value so far is kept in ``test_model.best_accuracy``.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        test_loader: iterable of ``(inputs, labels)`` batches.
        loss_fn: criterion called as ``loss_fn(outputs, labels)``.
        optimizer: unused; kept so the signature mirrors ``train_model``.
        epoch: zero-based epoch index (printed 1-based).
        last_path: where the latest ``state_dict`` is written.
        best_path: where the best-so-far ``state_dict`` is written.

    Returns:
        ``(mean_loss, accuracy_percent)`` as Python floats, averaged per sample.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    # eval(): layers such as BatchNorm use (and stop updating) their running
    # statistics; calling train() here let validation data alter the model.
    model.eval()
    dev = next(model.parameters()).device
    total_loss = 0.0
    total_corrects = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(dev)
            labels = labels.to(dev)
            outputs = model(inputs)
            loss = loss_fn(outputs, labels)
            preds = outputs.argmax(dim=1)
            total += labels.size(0)
            # Assumes 'mean' reduction: weight by batch size for a per-sample average.
            total_loss += loss.item() * inputs.size(0)
            total_corrects += preds.eq(labels).sum().item()
    if total == 0:
        raise ValueError("test_loader yielded no samples")
    loss = total_loss / total
    accuracy = 100.0 * total_corrects / total
    torch.save(model.state_dict(), last_path)
    # The best score must survive between calls. A local initialised to -1 was
    # reset every epoch, so best.pt was overwritten unconditionally.
    if accuracy > getattr(test_model, 'best_accuracy', float('-inf')):
        test_model.best_accuracy = accuracy
        torch.save(model.state_dict(), best_path)
    print("echo:%4d, 损失loss:%.5f, 准确率:%6.2f%%" % (epoch + 1, loss, accuracy))
    return loss, accuracy


# Best validation accuracy seen so far across test_model() calls.
test_model.best_accuracy = float('-inf')
# This is a helper, not a pytest test; stop pytest collecting it if it is ever
# imported into a test module (its name starts with "test_").
test_model.__test__ = False
# Loss, optimizer and data loaders for the training run.
loss_fn = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.0001)
train_loader = DataLoader(traindataset, batch_size=32, shuffle=True)
# Validation metrics do not depend on sample order, so there is no need to shuffle.
test_loader = DataLoader(testdataset, batch_size=32, shuffle=False)

# One training pass followed by one validation pass (which also checkpoints) per epoch.
for epoch in range(epoches):
    loss1, acc1 = train_model(model, train_loader, loss_fn, optimizer, epoch)
    loss2, acc2 = test_model(model, test_loader, loss_fn, optimizer, epoch)
classes = testdataset.classes
# Inference preprocessing: deterministic resize to the 224x224 network input
# followed by ImageNet normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model.eval()
# The dummy input only fixes the exported graph's input shape (1, 3, 224, 224).
# It must live on the model's device, otherwise export fails when CUDA is used.
input_tensor = torch.randn(1, 3, 224, 224, device=device)
# Convert the in-memory PyTorch model (final-epoch weights, as in last.pt) to ONNX.
torch.onnx.export(model, input_tensor, "last.onnx",
                  input_names=["input"], output_names=["output"])

path = './lifecycle/val/yourpath/37.png'  # path of the test image
# convert('RGB'): PNGs can be RGBA, grayscale or palette images, but the network
# and the 3-channel Normalize expect RGB. The with-block closes the file handle.
with Image.open(path) as raw_img:
    img = raw_img.convert('RGB')
img_p = transform(img).unsqueeze(0).to(device)
with torch.no_grad():
    output = model(img_p)
pred = output.argmax(dim=1).item()
plt.imshow(img)
plt.show()
# Per-class probabilities in percent, kept for inspection.
p = 100 * nn.Softmax(dim=1)(output).detach().cpu().numpy()[0]
print('该图像预测类别为:', classes[pred])
执行代码:
python torch_resnet.py
输出效果: