记录尝试用MobileOne模型做自己分类任务的过程,使用的模型的代码是官方给出的代码:
GitHub - apple/ml-mobileone: This repository contains the official implementation of the research paper, "An Improved One millisecond Mobile Backbone".https://github.com/apple/ml-mobileone里面包含了mobileone.py以及每个模型的预训练参数。
1、数据集和预处理
首先关于数据集构成如下:
数据集的读取:
# Build the ImageFolder datasets and their loaders for both splits.
splits = ['train', 'val']
image_datasets = {
    split: datasets.ImageFolder(os.path.join(args.data_dir, split), data_transforms[split])
    for split in splits
}
dataloaders = {
    split: torch.utils.data.DataLoader(
        image_datasets[split],
        batch_size=args.batch_size,
        shuffle=(split == 'train'),  # shuffle training data only
        num_workers=args.num_work,
    )
    for split in splits
}
# Total number of images in each split.
dataset_sizes = {split: len(image_datasets[split]) for split in splits}
# Class names, taken from the training split's folder names.
class_names = image_datasets['train'].classes
# Train and val must expose the same number of classes.
if len(image_datasets['train'].classes) != len(image_datasets['val'].classes):
    print("DataSet Error!")
    exit(-1)
num_classes = len(image_datasets['val'].classes)
图像的预处理我还是按照自己之前训练的流程来的,当然也可以添加一些其他的增强方法。
class MyRotationTransform:
    """Rotate the input by one of the given angles, chosen uniformly at random."""

    def __init__(self, angles):
        # Candidate rotation angles in degrees.
        self.angles = angles

    def __call__(self, x):
        chosen = random.choice(self.angles)
        return TF.rotate(x, chosen)
# Augmentation + normalization pipelines.
# BUGFIX: transforms.Normalize operates on tensors, but letterbox returns a
# PIL image — the original pipeline fed Normalize a PIL image, which raises
# a TypeError. ToTensor() must run before Normalize.
data_transforms = {
    'train': transforms.Compose([
        letterbox(image_size),  # aspect-preserving resize with gray padding
        # transforms.Resize(image_size),
        MyRotationTransform(angles=[90, 180, 270]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
为了防止resize方法造成图像的变形,我在这里使用了yolo里面的灰度填充的方法letterbox,代码如下:
import cv2
import numpy as np
from PIL import Image
class letterbox():
    """YOLO-style letterboxing: resize an image to ``shape`` while preserving
    aspect ratio, padding the remainder with a constant color.

    Accepts and returns a PIL image; works internally on an OpenCV BGR array.

    :param shape: target (height, width), or a single int for a square target
    :param color: padding color as a BGR tuple
    :param auto: pad only up to a multiple of ``stride`` (minimum rectangle)
    :param scaleFill: stretch to the target without padding
    :param scaleup: allow upscaling (disable for better val mAP)
    :param stride: stride used by ``auto`` padding
    """

    def __init__(self, shape=(1600, 1600), color=(114, 114, 114), auto=False,
                 scaleFill=False, scaleup=True, stride=32):
        # BUGFIX: the int->tuple promotion must apply to the *target* shape.
        # The original tested the source image shape inside __call__ (always a
        # tuple, so the branch never fired) and, had it fired, would have
        # overwritten the target with the image's own size.
        if isinstance(shape, int):
            shape = (shape, shape)
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride
        self.color = color
        self.new_shape = shape

    def __call__(self, x):
        im = cv2.cvtColor(np.asarray(x), cv2.COLOR_RGB2BGR)
        shape = im.shape[:2]  # current (height, width)

        # Scale ratio (new / old)
        r = min(self.new_shape[0] / shape[0], self.new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = self.new_shape[1] - new_unpad[0], self.new_shape[0] - new_unpad[1]  # wh padding
        if self.auto:  # minimum rectangle: pad only up to a stride multiple
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)
        elif self.scaleFill:  # stretch: no padding at all
            dw, dh = 0.0, 0.0
            new_unpad = (self.new_shape[1], self.new_shape[0])
            ratio = self.new_shape[1] / shape[1], self.new_shape[0] / shape[0]  # width, height ratios

        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad:  # resize only when the size actually changes
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_NEAREST)
            # im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.color)  # add border
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        return im
if __name__ == '__main__':
    # Quick visual check: letterbox a sample image down to 224x224.
    img = Image.open('../image/0_06145527019_31.bmp')
    print("img ", img.size)
    img.show()
    cv2.waitKey(0)

    target = [224, 224]
    boxer = letterbox(target)
    result = boxer(img)
    print("out ", result.size)
    result.show()
    cv2.waitKey(0)
2、训练:
训练过程中的模型加载都是根据官方给出的代码来操作的,我这里使用的是s1模型,我的完整的训练代码如下:
from __future__ import print_function, division
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision import datasets, transforms
from torchsummary import summary
from model.mobileOne import *
from utils.augment import letterbox
import os
import copy
import torchvision.transforms.functional as TF
import random
from torch.utils.tensorboard import SummaryWriter
# Command-line configuration for the training run.
parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--epoch', default=70, type=int, help='number of total epochs to run')
parser.add_argument('--data_dir', default=r'D:/image/dataset_name/', type=str)
# BUGFIX: lr must be parsed as float; with type=int, passing `--lr 0.001`
# on the command line raised an argparse conversion error.
parser.add_argument('--lr', default=0.001, type=float, help='learn rate')
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--step_size', default=10, type=int, help='step_size for scheduler')
parser.add_argument('--gamma', default=0.5, type=float, help='update the multiplication factor of lr')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True, help='use pre-trained model')
parser.add_argument('--num_work', default=8, type=int)
parser.add_argument('--save_path', default='../weights/mobileone.pth', type=str)
parser.add_argument('--save_log', default='mobileOne_s1.log', type=str)
global args
args = parser.parse_args()
#############################################################################################
image_size = [224, 224]
# Robustness fix: fall back to CPU when CUDA is unavailable instead of
# crashing at the first .to(device) call.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Create the per-epoch loss/accuracy log file on first run.
if not os.path.exists(args.save_log):
    with open(args.save_log, "w") as f:
        pass
import time
date = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
prefix = date + '/'
writer = SummaryWriter('log/')
class MyRotationTransform:
    """Rotate the input by one of the given angles, picked at random per call."""

    def __init__(self, angles):
        # Pool of rotation angles (degrees) to choose from.
        self.angles = angles

    def __call__(self, x):
        picked = random.choice(self.angles)
        return TF.rotate(x, picked)
# Data augmentation and normalization for training
# Just normalization for validation
# Data augmentation and normalization for training; just normalization for
# validation.
# BUGFIX: transforms.Normalize requires a tensor, but letterbox returns a PIL
# image — ToTensor() must precede Normalize or the pipeline raises TypeError.
data_transforms = {
    'train': transforms.Compose([
        letterbox(image_size),  # aspect-preserving resize with gray padding
        # transforms.Resize(image_size),
        MyRotationTransform(angles=[90, 180, 270]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
    'val': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]),
}
##################################### dataset #############################################
# Datasets and loaders for the two splits; only the training loader shuffles.
splits = ['train', 'val']
image_datasets = {
    s: datasets.ImageFolder(os.path.join(args.data_dir, s), data_transforms[s])
    for s in splits
}
dataloaders = {
    s: torch.utils.data.DataLoader(
        image_datasets[s],
        batch_size=args.batch_size,
        shuffle=(s == 'train'),
        num_workers=args.num_work,
    )
    for s in splits
}
dataset_sizes = {s: len(image_datasets[s]) for s in splits}
class_names = image_datasets['train'].classes
# Abort if the two splits disagree on the number of classes.
if len(image_datasets['train'].classes) != len(image_datasets['val'].classes):
    print("DataSet Error!")
    exit(-1)
num_classes = len(image_datasets['val'].classes)
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Standard train/val loop over the module-level dataloaders.

    For each epoch runs a training phase and a validation phase, logs overall
    and per-class accuracy to ``args.save_log``, steps ``scheduler`` once per
    epoch, and saves the whole model to ``args.save_path`` whenever validation
    accuracy improves.

    :param model: network already moved to ``device``
    :param criterion: loss function (e.g. CrossEntropyLoss)
    :param optimizer: optimizer over ``model`` parameters
    :param scheduler: LR scheduler stepped after each training phase
    :param num_epochs: number of epochs to run
    :return: the model after the final epoch (best weights are on disk)
    """
    since = time.time()
    best_acc = 0.0
    for epoch in range(num_epochs):
        t1 = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        localtime = time.asctime(time.localtime(time.time()))
        print(localtime)
        print('-' * 10)
        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode
            running_loss = 0.0
            running_corrects = 0
            class_correct = list(0. for i in range(num_classes))
            class_total = list(0. for i in range(num_classes))
            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)
                # zero the parameter gradients
                optimizer.zero_grad()
                # forward; track history only in the training phase
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)
                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()
                # statistics: per-class hit counts (labels already on device)
                c = torch.eq(preds, labels).squeeze()
                size = int(labels.shape[0])
                for i in range(size):
                    label = labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            if phase == 'train':
                scheduler.step()
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            # Per-class accuracy.
            # BUGFIX: the original referenced an undefined name `class_image`;
            # the module-level class-name list is `class_names`.
            for i in range(num_classes):
                # Guard against a class with no samples in this phase.
                total = class_total[i] if class_total[i] else 1
                print('Acc of %5s : %4f %%' % (
                    class_names[i], 100 * class_correct[i] / total))
                with open(args.save_log, 'a') as f:
                    f.write(' {} Acc: {:.4f}\n'.format(class_names[i], 100 * class_correct[i] / total))
            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print()
            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model, args.save_path)
            with open(args.save_log, 'a') as f:
                f.write('Epoch {}, {} Loss: {:.4f} Acc: {:.4f}\n'.format(epoch, phase, epoch_loss, epoch_acc))
                f.write('best: {:.4f}\n'.format(best_acc))
                f.write('\n')
        print("best : ", best_acc)
        print("time = ", time.time() - t1)
        print()
    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))
    return model
if __name__ == '__main__':
    # Build an s1 MobileOne with the classifier sized for our dataset.
    model_ft = mobileone(variant='s1', num_classes=num_classes)
    # BUGFIX: use a raw string for the Windows path — backslashes in a normal
    # string literal are invalid-escape traps (silently deprecated in recent
    # Python). map_location makes a GPU-saved checkpoint loadable anywhere.
    checkpoint = torch.load(r'D:\mobileOne\pretrained\mobileone_s1_unfused.pth.tar',
                            map_location='cpu')
    # Drop the pretrained classifier weights; our class count differs.
    checkpoint.pop('linear.weight')
    checkpoint.pop('linear.bias')
    model_ft.load_state_dict(checkpoint, strict=False)
    model_ft = model_ft.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr, momentum=args.momentum)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=args.step_size, gamma=args.gamma)
    model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=args.epoch)
3、重参化
训练完成之后可以得到一个mobileone.pth模型,然后对其进行重参化:
import copy
from torch import nn
from model.mobileOne import mobileone
import torch
# Load the full trained model object (saved via torch.save(model, ...)).
# NOTE(review): this unpickles the class by reference, so model.mobileOne
# must be importable at load time — confirm the package layout matches.
model = torch.load('../weights/mobileone.pth')
def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """Return an inference-mode copy of a MobileOne-style model.

    Every submodule exposing a ``reparameterize()`` method has its
    multi-branched training structure folded into a single branch. The
    input model is deep-copied first, so the caller's graph is untouched.

    :param model: MobileOne model in train mode.
    :return: MobileOne model in inference mode.
    """
    # Operate on a copy to avoid editing the original graph in place.
    folded = copy.deepcopy(model)
    for submodule in folded.modules():
        if hasattr(submodule, 'reparameterize'):
            submodule.reparameterize()
    return folded
# Fold the trained model's branches and persist the inference-ready copy.
model_rep = reparameterize_model(model)
torch.save(model_rep, 'model_rep.pth')
整个的训练过程就结束了,可以直接用重参化之后的模型做推理。
4、总结
在我自己的图像上的训练结果并不是太好,我使用letterbox做填充的mobilenetv2最后在测试集上的分类精度是99.43%,用s1的分类精度是99.38%,当然对于不同的数据集肯定适用度也不同。最后测试推理速度时发现,推理的速度确实和mobileNetv2的速度也相差无几,和论文里的情况相符。