EfficientNetV2深度学习记录——代码复现

神经网络/深度学习

第二章 Python机器学习入门之EfficientNetV2的使用



前言

本文主要是复现efficientnetv2网络代码,训练自己的材质分类模型,学习记录下来。
大佬文章:https://blog.csdn.net/qq_37541097/article/details/116933569
大佬的讲解视频:https://www.bilibili.com/video/BV1Xy4y1g74u/?spm_id_from=333.1007.top_right_bar_window_history.content.click&vd_source=b9a1a486cbe5d7fe623135210f75aca8
论文下载地址:https://arxiv.org/abs/2104.00298
原论文提供代码:https://github.com/google/automl/tree/master/efficientnetv2
在这里插入图片描述


提示:以下是本篇文章正文内容,下面案例可供参考

一、EfficientNetV2是什么?

EfficientNetV2是由谷歌提出的一种新型神经网络架构,用于图像分类任务。它在EfficientNet的基础上进行了改进,通过优化模型的结构和训练过程,提高了模型的效率和性能。
EffNetV2-S(21k)(红色曲线)是一个EfficientNetV2家族的模型,使用21k个类别的数据进行预训练。该模型在较短的训练时间内(约0.5TPU天)达到了85%准确率,其准确率之高,模型大小之小,选为这次训练的基础模型(自己的小笔记本是4060labtap,感觉没啥问题)
各大模型的准确率
模型可以去大佬的文章中找到代码链接,再从链接中找到百度网盘的模型下载链接。

二、EfficientNetV2代码的复现

首先给大家看一下整体的目录
文件目录
我这里材质分类分了六种,分别是3D,玻璃,镜面 ,金属,平滑,纹理(当然这个是我人工定义的,大家可以根据自己的需求进行更改)

1.准备工作

在train文件夹下面设置好你所设定的种类,我这里六种,我就设置了六个文件夹并且以种类的名字命名,里面添加好各个种类的图片(图片根据自己的需求添加就行)

2.训练模型

train.py代码

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import efficientnetv2_s as create_model
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    print(args)
    print('Start Tensorboard with "tensorboard --logdir=runs", view at http://localhost:6006/')
    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(img_size[num_model][0]),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([transforms.Resize(img_size[num_model][1]),
                                   transforms.CenterCrop(img_size[num_model][1]),
                                   transforms.ToTensor(),
                                   transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])}

    # 实例化训练数据集
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])

    # 实例化验证数据集
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    # 如果存在预训练权重则载入
    model = create_model(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            print(model.load_state_dict(load_weights_dict, strict=False))
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    # 是否冻结权重
    if args.freeze_layers:
        for name, para in model.named_parameters():
            # 除head外,其他权重全部冻结
            if "head" not in name:
                para.requires_grad_(False)
            else:
                print("training {}".format(name))

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=1E-4)
    # Scheduler https://arxiv.org/pdf/1812.01187.pdf
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf  # cosine
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        # train
        train_loss, train_acc = train_one_epoch(model=model,
                                                optimizer=optimizer,
                                                data_loader=train_loader,
                                                device=device,
                                                epoch=epoch)

        scheduler.step()

        # validate
        val_loss, val_acc = evaluate(model=model,
                                     data_loader=val_loader,
                                     device=device,
                                     epoch=epoch)

        tags = ["train_loss", "train_acc", "val_loss", "val_acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], train_loss, epoch)
        tb_writer.add_scalar(tags[1], train_acc, epoch)
        tb_writer.add_scalar(tags[2], val_loss, epoch)
        tb_writer.add_scalar(tags[3], val_acc, epoch)
        tb_writer.add_scalar(tags[4], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=6)#训练种类
    parser.add_argument('--epochs', type=int, default=300)#训练轮次
    parser.add_argument('--batch-size', type=int, default=8)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.01)

    # 数据集所在根目录
    parser.add_argument('--data-path', type=str,
                        default="数据集所在位置")
    parser.add_argument('--weights', type=str, default="模型所在位置model/pre_efficientnetv2-s.pth",
                        help='initial weights path')
    parser.add_argument('--freeze-layers', type=bool, default=True)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()

    main(opt)

其中它会自动生成class_indices.json,可以通过这个来对种类进行观察
材质种类
出现这个界面说明就对了
在这里插入图片描述
训练好的模型会在weigths中显示

3.进行预测

predict.py代码

import os
import json
import torch
from PIL import Image
from torchvision import transforms
from model import efficientnetv2_s as create_model

def main():
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    img_size = {"s": [300, 384],  # train_size, val_size
                "m": [384, 480],
                "l": [384, 480]}
    num_model = "s"

    # 载入图片
    img_path = "预测图片位置"
    assert os.path.exists(img_path), "file: '{}' does not exist.".format(img_path)

    # 载入图片并应用转换
    img = Image.open(img_path)

    # 根据图像模式设置转换流程
    if img.mode == 'RGBA':
        # RGBA四通道图像,移除Alpha通道
        data_transform = transforms.Compose([
            transforms.Resize(img_size[num_model][1]),
            transforms.CenterCrop(img_size[num_model][1]),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x[:3]),  # 只取前三个RGB通道
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    elif img.mode == 'RGB':
        # RGB三通道图像,直接处理
        data_transform = transforms.Compose([
            transforms.Resize(img_size[num_model][1]),
            transforms.CenterCrop(img_size[num_model][1]),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    else:
        raise ValueError("Unsupported image mode: {}".format(img.mode))

    # 应用转换
    img = data_transform(img)
    img = torch.unsqueeze(img, 0)  # 确保这里添加了批次维度

    # 读取类别索引文件
    json_path = '索引文件位置'
    assert os.path.exists(json_path), "file: '{}' does not exist.".format(json_path)
    with open(json_path, "r") as f:
        class_indict = json.load(f)

    # 英文到中文的映射
    english_to_chinese = {
        "3D": "3D",
        "Diascope": "玻璃",
        "Gloss": "镜面",
        "Luster": "金属",
        "Smooth": "平滑",
        "Texture": "纹理",
    }

    # 创建模型并加载权重
    model = create_model(num_classes=len(class_indict)).to(device)
    model_weight_path = "训练好的模型"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()

    # 进行预测
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()

    # 使用映射显示结果
    class_name_english = class_indict[str(predict_cla)]
    class_name_chinese = english_to_chinese.get(class_name_english, "未知类别")
    print_res = "类别: {}   概率: {:.3f}".format(class_name_chinese, predict[predict_cla].item())
    print(print_res)  # 打印最高预测结果和概率

if __name__ == '__main__':
    main()

对图片进行预测,预测结果如下(手机界面我都设置成了平滑)图片与预测结果


总结

以上就是今天代码所复现的内容,本文仅仅简单复现了EfficientNetV2的代码并训练预测,如有不足还望批评指正。

PS:修改的数据集位置分别为:
第131行训练集所在位置(train.py);
第132行原模型的位置(train.py);
第17行预测图片所在位置(predict.py);
第49行索引文件位置(predict.py);
第66行训练好的模型所在位置(predict.py)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值