Pytorch版的Efficientnet训练自己的数据集

使用Efficientnet进行图像分类:从数据准备到训练与测试
本文介绍了如何利用Efficientnet进行图像分类任务,包括环境配置、数据预处理(切分训练集与验证集、图片补边)、模型训练(预训练模型加载、数据读取、学习率衰减策略)和测试过程。作者展示了完整的代码实现,并分享了训练和测试的结果。


前言

  最近,自己需要一个分类网络来完成一项任务,于是便想起了身边人推荐过的Efficientnet,据说效果是较为稳定的,所以自己来一探究竟,示例的话就用个最简单的二分类吧。


一、环境搭建

本人使用的环境为:
python3.6
torch=1.5
torchvision =0.6.0
opencv-python=4.5.1.48

以上这些仅供参考,无需一致,重要的使我们还需要安装pytorch集合进来的Efficientnet模块,在我们要使用的python环境下,执行命令

pip install efficientnet_pytorch

其他依赖项到时逐个安装即可。


二、数据准备

1.数据摆放

  原始数据摆放如下:

在这里插入图片描述

  也就是以类别名来命名文件夹名,将对应的类别图片放置对应的文件夹下,一般来说,分类任务的数据集大多都是这样来摆放的。

2、训练集和验证集切分

  这一步只需要运行dataset.py即可,它会按照我们制定的比例将我们的数据集进行切分开,同时,为了减少直接resize带来的图片变形的弊端,这里在切分的同时我对数据还进行补边的操作,也就是将数据尽量变为正方形的样子,代码如下

#为efficientnet训练分类的数据进行预处理(训练集切分+补边)
import os
import glob
import cv2
import random
from pathlib import Path


#补边,这一步主要是为了将图片填充为正方形,防止直接resize导致图片变形
def expend_img(img):
    '''
    :param img: 图片数据
    :return:
    '''
    fill_pix=[122,122,122] #填充色素,可自己设定
    h,w=img.shape[:2]
    if h>=w: #左右填充
        padd_width=int(h-w)//2
        padd_top,padd_bottom,padd_left,padd_right=0,0,padd_width,padd_width #各个方向的填充像素
    elif h<w: #上下填充
        padd_high=int(w-h)//2
        padd_top,padd_bottom,padd_left,padd_right=padd_high,padd_high,0,0 #各个方向的填充像素
    new_img = cv2.copyMakeBorder(img,padd_top,padd_bottom,padd_left,padd_right,cv2.BORDER_CONSTANT, value=fill_pix)
    return new_img


#切分训练集和测试集,并进行补边处理
def split_train_test(img_dir,save_dir,train_val_num):
    '''
    :param img_dir: 原始图片路径,注意是所有类别所在文件夹的上一级目录
    :param save_dir: 保存图片路径
    :param train_val_num: 切分比例
    :return:
    '''
    img_dir_list=glob.glob(img_dir+os.sep+"*")#获取每个类别所在的路径(一个类别对应一个文件夹)
    for class_dir in img_dir_list:
        class_name=class_dir.split(os.sep)[-1] #获取当前类别
        img_list=glob.glob(class_dir+os.sep+"*") #获取每个类别文件夹下的所有图片
        all_num=len(img_list) #获取总个数
        train_list=random.sample(img_list,int(all_num*train_val_num)) #训练集图片所在路径
        save_train=save_dir+os.sep+"train"+os.sep+class_name
        save_val=save_dir+os.sep+"val"+os.sep+class_name
        os.makedirs(save_train,exist_ok=True)
        os.makedirs(save_val,exist_ok=True) #建立对应的文件夹
        print(class_name+" trian num",len(train_list))
        print(class_name+" val num",all_num-len(train_list))
        #保存切分好的数据集
        for imgpath in img_list:
            imgname=Path(imgpath).name #获取文件名
            if imgpath in train_list:
                img=cv2.imread(imgpath)
                new_img=expend_img(img)
                cv2.imwrite(save_train+os.sep+imgname,new_img)
            else: #将除了训练集意外的数据均视为验证集
                img = cv2.imread(imgpath)
                new_img = expend_img(img)
                cv2.imwrite(save_val + os.sep + imgname, new_img)

    print("split train and val finished !")

这里,也对代码内容和相关参数进行了注释,理解起来应该不是很难
运行它的时候,我们只需要调用split_train_test()函数,输入指定的参数(3个)即可,需要注意的是,这里给的原始图片的路径是所有类别文件夹的上一级,程序会依次遍历它下面的各个文件夹来进行切分运行完成后,会生成对应的训练集和测试集,如下图:
在这里插入图片描述
train和val里也会有生成各个类别的文件夹用于储存不同类别的数据,需要注意的是,这里存放的数据是我经过补边之后的,对原路径的数据集不会有改动,填充颜色我默认设置为了灰色,可以根据自己爱好在代码中自行更改

三、训练

1.预训练模型下载

这里可以对代码进行更改,使它自动下载模型,我是觉得慢,所以手动下载了,网址如下:

'''
efficientnet-b0: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b0-355c32eb.pth
efficientnet-b1: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b1-f1951068.pth
efficientnet-b2: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b2-8bb594d6.pth
efficientnet-b3: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b3-5fb5a3c3.pth
efficientnet-b4: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b4-6ed6700e.pth
efficientnet-b5: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b5-b6417697.pth
efficientnet-b6: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b6-c76e70fd.pth
efficientnet-b7: https://github.com/lukemelas/EfficientNet-PyTorch/releases/download/1.0/efficientnet-b7-dcc49843.pth
'''

我选用的是b0

2.加载模型

代码如下(示例):

        base_model = EfficientNet.from_name('efficientnet-b0') #加载模型,使用b几的就改为b几
        state_dict = torch.load(self.weights)
        base_model.load_state_dict(state_dict)
        # 修改全连接层
        num_ftrs = base_model._fc.in_features
        base_model._fc = nn.Linear(num_ftrs, self.class_num)
        self.model = base_model.to(device)

3.数据读取部分

这里对数据进行了指定的数据变换(增强),可以根据需求进行删改,代码如下:

	#数据处理
    def process(self):
        # 数据增强
        data_transforms = {
   
   
            'train': transforms.Compose([
                transforms.Resize((self.imgsz, self.imgsz)),  # resize
                transforms.CenterCrop((self.imgsz, self.
评论 25
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值