Image Classification with MobileOne

This post records my attempt to train the MobileOne model on my own classification task, using the official model code:

GitHub - apple/ml-mobileone: the official implementation of the research paper "An Improved One millisecond Mobile Backbone" (https://github.com/apple/ml-mobileone). The repository contains mobileone.py as well as pretrained weights for each model variant.

1. Dataset and Preprocessing

The dataset follows the standard `ImageFolder` layout: under `data_dir` there is a `train/` and a `val/` directory, each containing one subdirectory per class.

Reading the dataset:

image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                             shuffle=(x=='train'), num_workers=args.num_work)
              for x in ['train', 'val']}

# total number of images in the training and validation sets
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

# names of the classes
class_names = image_datasets['train'].classes

# check that the train and val sets have the same number of classes
if len(image_datasets['train'].classes) != len(image_datasets['val'].classes):
    print("DataSet Error!")
    exit(-1)

num_classes = len(image_datasets['val'].classes)
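The length check above only catches a mismatched class count, not mismatched class names. A stricter check (a sketch of my own, not part of the original script) compares the sorted class lists directly:

```python
def check_class_consistency(train_classes, val_classes):
    """Return True only if train and val contain exactly the same class names."""
    return sorted(train_classes) == sorted(val_classes)

# Illustration with hypothetical class lists:
print(check_class_consistency(['cat', 'dog'], ['dog', 'cat']))   # order does not matter
print(check_class_consistency(['cat', 'dog'], ['cat', 'bird']))  # a renamed class is caught
```

In the script above this would be called as `check_class_consistency(image_datasets['train'].classes, image_datasets['val'].classes)`.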

For image preprocessing I followed the same pipeline I have used in earlier training runs; other augmentation methods could of course be added.

class MyRotationTransform:
    """Rotate by one of the given angles."""
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)


data_transforms = {
    'train': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        MyRotationTransform(angles=[90, 180, 270]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),  # Normalize expects a tensor scaled to [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        transforms.ToTensor(),  # Normalize expects a tensor scaled to [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
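`Normalize` applies `(x - mean) / std` per channel to a tensor already scaled to [0, 1] by `ToTensor`. A quick sketch of the arithmetic on a single RGB pixel (illustrative values):

```python
mean = [0.485, 0.456, 0.406]  # ImageNet channel means
std = [0.229, 0.224, 0.225]   # ImageNet channel stds

pixel = [0.5, 0.5, 0.5]  # one RGB pixel after ToTensor() scaling
normalized = [(p - m) / s for p, m, s in zip(pixel, mean, std)]
print([round(v, 4) for v in normalized])
```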

To keep a plain resize from distorting the image, I used the gray-padding letterbox method from YOLO; the code is as follows:

import cv2
import numpy as np
from PIL import Image


class letterbox():
    def __init__(self, shape=(1600, 1600), color=(114, 114, 114), auto=False,
                 scaleFill=False, scaleup=True, stride=32):
        self.auto = auto
        self.scaleFill = scaleFill
        self.scaleup = scaleup
        self.stride = stride
        self.color = color
        self.new_shape = shape

    def __call__(self, x):
        im = cv2.cvtColor(np.asarray(x), cv2.COLOR_RGB2BGR)
        shape = im.shape[:2]  # original image size (h, w)
        if isinstance(self.new_shape, int):  # allow a bare int as the target size
            self.new_shape = (self.new_shape, self.new_shape)

        # Scale ratio (new / old)
        r = min(self.new_shape[0] / shape[0], self.new_shape[1] / shape[1])
        if not self.scaleup:  # only scale down, do not scale up (for better val mAP)
            r = min(r, 1.0)

        # Compute padding
        ratio = r, r  # width, height ratios
        new_unpad = int(round(shape[1] * r)), int(round(shape[0] * r))
        dw, dh = self.new_shape[1] - new_unpad[0], self.new_shape[0] - new_unpad[1]  # wh padding
        if self.auto:  # minimum rectangle
            dw, dh = np.mod(dw, self.stride), np.mod(dh, self.stride)  # wh padding
        elif self.scaleFill:  # stretch
            dw, dh = 0.0, 0.0
            new_unpad = (self.new_shape[1], self.new_shape[0])
            ratio = self.new_shape[1] / shape[1], self.new_shape[0] / shape[0]  # width, height ratios

        dw /= 2  # divide padding into 2 sides
        dh /= 2

        if shape[::-1] != new_unpad:  # resize
            im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_NEAREST)
            # im = cv2.resize(im, new_unpad, interpolation=cv2.INTER_LINEAR)
        top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
        left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
        im = cv2.copyMakeBorder(im, top, bottom, left, right, cv2.BORDER_CONSTANT, value=self.color)  # add border
        im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
        return im

if __name__ == '__main__':
    img = Image.open('../image/0_06145527019_31.bmp')
    print("img ", img.size)
    img.show()
    shape = [224, 224]
    transform = letterbox(shape)
    out = transform(img)
    print("out ", out.size)
    out.show()
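The padding arithmetic in `__call__` can be checked without OpenCV. Below is a pure-Python sketch of the same computation (scale ratio, resized size, per-side padding); the function name `letterbox_geometry` is my own, chosen for illustration:

```python
def letterbox_geometry(src_hw, dst_hw, scaleup=True):
    """Compute the resize target and per-side gray padding for a letterbox.

    src_hw, dst_hw: (height, width) of the source image and the target canvas.
    Returns ((new_w, new_h), (left, right, top, bottom)).
    """
    r = min(dst_hw[0] / src_hw[0], dst_hw[1] / src_hw[1])  # scale ratio (new / old)
    if not scaleup:
        r = min(r, 1.0)  # only shrink, never enlarge
    new_w, new_h = int(round(src_hw[1] * r)), int(round(src_hw[0] * r))
    dw, dh = (dst_hw[1] - new_w) / 2, (dst_hw[0] - new_h) / 2  # split padding over 2 sides
    top, bottom = int(round(dh - 0.1)), int(round(dh + 0.1))
    left, right = int(round(dw - 0.1)), int(round(dw + 0.1))
    return (new_w, new_h), (left, right, top, bottom)

# A 640x480 image (h=480, w=640) letterboxed onto a 224x224 canvas:
size, pad = letterbox_geometry((480, 640), (224, 224))
print(size, pad)  # → (224, 168) (0, 0, 28, 28)
```

The image is scaled to 224x168 and 28 rows of gray are added above and below, so the content keeps its aspect ratio.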

2. Training

Model loading during training follows the official code; I used the s1 variant. My complete training script is as follows:

from __future__ import print_function, division
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
from torchvision import datasets, transforms
from torchsummary import summary
from model.mobileOne import *
from utils.augment import letterbox
import os
import copy
import torchvision.transforms.functional as TF
import random
from torch.utils.tensorboard import SummaryWriter

parser = argparse.ArgumentParser()
parser.add_argument('--batch_size', default=32, type=int)
parser.add_argument('--epoch', default=70, type=int, help='number of total epochs to run')
parser.add_argument('--data_dir', default=r'D:/image/dataset_name/',type=str)
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--momentum', default=0.9, type=float)
parser.add_argument('--step_size', default=10, type=int, help='step_size for scheduler')
parser.add_argument('--gamma', default=0.5, type=float, help='update the multiplication factor of lr')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', default=True, help='use pre-trained model')
parser.add_argument('--num_work', default=8, type=int)
parser.add_argument('--save_path',default='../weights/mobileone.pth',type=str)
parser.add_argument('--save_log',default='mobileOne_s1.log', type=str)

global args
args = parser.parse_args()

#############################################################################################
image_size=[224, 224]
device = torch.device("cuda:0")

# log file for per-epoch loss and classification accuracy
if not os.path.exists(args.save_log):
    with open(args.save_log, "w") as f:
        pass

import time
date = time.strftime("%Y-%m-%d-%H:%M:%S", time.localtime())
prefix = date+'/'
writer = SummaryWriter('log/')

class MyRotationTransform:
    """Rotate by one of the given angles."""
    def __init__(self, angles):
        self.angles = angles

    def __call__(self, x):
        angle = random.choice(self.angles)
        return TF.rotate(x, angle)

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        MyRotationTransform(angles=[90, 180, 270]),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),  # Normalize expects a tensor scaled to [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        letterbox(image_size),
        # transforms.Resize(image_size),
        transforms.ToTensor(),  # Normalize expects a tensor scaled to [0, 1]
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

##################################### dataset #############################################

image_datasets = {x: datasets.ImageFolder(os.path.join(args.data_dir, x),data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=args.batch_size,
                                             shuffle=(x=='train'), num_workers=args.num_work)
              for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
if len(image_datasets['train'].classes) != len(image_datasets['val'].classes):
    print("DataSet Error!")
    exit(-1)

num_classes = len(image_datasets['val'].classes)

def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    since = time.time()
    best_acc = 0.0

    for epoch in range(num_epochs):
        t1 = time.time()
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        localtime = time.asctime(time.localtime(time.time()))
        print(localtime)
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0
            class_correct = list(0. for i in range(num_classes))
            class_total = list(0. for i in range(num_classes))

            # Iterate over data.
            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                c = torch.eq(preds, labels)  # labels are already on the device
                size = int(labels.shape[0])
                for i in range(size):
                    label = labels[i]
                    class_correct[label] += c[i].item()
                    class_total[label] += 1

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            if phase == 'train':
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]
            
            # print the classification accuracy of each class
            for i in range(num_classes):
                print('Acc of %5s : %4f %%' % (
                    class_names[i], 100 * class_correct[i] / class_total[i]))
                with open(args.save_log, 'a') as f:
                    f.write(' {} Acc: {:.4f}\n'.format(class_names[i], 100 * class_correct[i] / class_total[i]))

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(
                phase, epoch_loss, epoch_acc))
            print()

            # deep copy the model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model, args.save_path)

            with open(args.save_log, 'a') as f:
                f.write('Epoch {}, {} Loss: {:.4f} Acc: {:.4f}\n'.format(epoch, phase, epoch_loss, epoch_acc))
                f.write('best: {:.4f}\n'.format(best_acc))
                f.write('\n')

        print("best : ",best_acc)
        print("time = ",time.time()-t1)
        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:4f}'.format(best_acc))

    return model

if __name__ == '__main__':

    model_ft = mobileone(variant='s1', num_classes=num_classes)
    checkpoint = torch.load(r'D:\mobileOne\pretrained\mobileone_s1_unfused.pth.tar')
    # drop the final linear layer so the head is rebuilt with our own class count
    checkpoint.pop('linear.weight')
    checkpoint.pop('linear.bias')
    model_ft.load_state_dict(checkpoint, strict=False)

    model_ft = model_ft.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer_ft = optim.SGD(model_ft.parameters(), lr=args.lr, momentum=args.momentum)
    exp_lr_scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=args.step_size, gamma=args.gamma)

    model = train_model(model_ft, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=args.epoch)
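Popping `linear.weight` and `linear.bias` works because those are the exact key names of MobileOne's classifier head. A slightly more generic sketch (with a toy state dict standing in for the real checkpoint) drops every key under a given head prefix before the `strict=False` load:

```python
def drop_head_keys(state_dict, prefix='linear.'):
    """Return a copy of a state dict without the classifier-head entries."""
    return {k: v for k, v in state_dict.items() if not k.startswith(prefix)}

# Illustration with a hypothetical miniature state dict:
ckpt = {'stage0.conv.weight': 1, 'linear.weight': 2, 'linear.bias': 3}
print(sorted(drop_head_keys(ckpt)))  # → ['stage0.conv.weight']
```

With the head keys removed, `load_state_dict(..., strict=False)` fills the backbone from the checkpoint and leaves the freshly initialized classifier untouched.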

3. Reparameterization

Training produces a mobileone.pth model, which is then reparameterized:

import copy
from torch import nn
from model.mobileOne import mobileone
import torch

model = torch.load('../weights/mobileone.pth')

def reparameterize_model(model: torch.nn.Module) -> nn.Module:
    """ Method returns a model where a multi-branched structure
        used in training is re-parameterized into a single branch
        for inference.

    :param model: MobileOne model in train mode.
    :return: MobileOne model in inference mode.
    """
    # Avoid editing original graph
    model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'reparameterize'):
            module.reparameterize()
    return model

model_rep = reparameterize_model(model)
torch.save(model_rep, 'model_rep.pth')

That completes the whole training process; the reparameterized model can be used directly for inference.
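Under the hood, each branch's `reparameterize()` folds its batch norm into the preceding convolution: the fused weight is w·γ/√(σ²+ε) and the fused bias is β − μ·γ/√(σ²+ε), after which the parallel branches are summed into a single kernel. A scalar sketch of the BN-folding step, with illustrative numbers of my own:

```python
import math

def fold_bn(w, gamma, beta, mean, var, eps=1e-5):
    """Fold a BatchNorm (gamma, beta, running mean/var) into a conv weight and bias."""
    scale = gamma / math.sqrt(var + eps)
    return w * scale, beta - mean * scale

# Check that BN(w * x) == w_f * x + b_f for an arbitrary input x:
w, (gamma, beta, mean, var) = 2.0, (1.5, 0.3, 0.4, 0.25)
w_f, b_f = fold_bn(w, gamma, beta, mean, var)
x = 1.7
bn_out = gamma * (w * x - mean) / math.sqrt(var + 1e-5) + beta
print(abs((w_f * x + b_f) - bn_out) < 1e-9)  # the two paths agree
```

This is why the fused model computes the same function as the trained multi-branch model while running as a single plain convolution per layer.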

4. Summary

The results on my own images were not outstanding: MobileNetV2 with the same letterbox padding reached 99.43% classification accuracy on the test set, while MobileOne-s1 reached 99.38%; of course, how well each model suits a dataset will differ from dataset to dataset. When I finally measured inference speed, it was indeed about the same as MobileNetV2, which matches what the paper reports.
