【 ICCV代码复现】Swin Transformer图像分类实战教程 (训练自己的数据集)

本文详细介绍了如何在SwinTransformer中进行图像分类,包括环境配置(如使用pytorch和mmcv,以及CUDA和PyTorch版本),修改config.py、build.py和utils.py中的参数,以及训练和评估过程。还提供了处理常见错误如TypeError:init()gotanunexpectedkeywordargumentt_mul的方法。
摘要由CSDN通过智能技术生成

我用的是官方的代码,还有一位大神的集成代码也很不错,根据自己需求选择(不过选择大神的代码就不能看我这个教程了)https://github.com/WZMIAOMIAO/deep-learning-for-image-processing/tree/master/pytorch_classification/swin_transformer

论文地址:https://arxiv.org/pdf/2103.14030.pdf
GitHub地址:https://github.com/microsoft/Swin-Transformer/tree/main
在这里插入图片描述

一、环境配置

1.官方环境配置

基础pytorch、mmcv等,可以按照官方的教程如以下信息:
https://github.com/microsoft/Swin-Transformer/blob/main/get_started.md


我们推荐使用 pytorch docker nvcr>=21.05 by nvidia:
https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch
Clone this repo:

git clone https://github.com/microsoft/Swin-Transformer.git
cd Swin-Transformer

创建conda虚拟环境并激活:

conda create -n swin python=3.7 -y
conda activate swin

Install CUDA>=10.2 with cudnn>=7 following the official installation instructions
Install PyTorch>=1.8.0 and torchvision>=0.9.0 with CUDA>=10.2:

conda install pytorch==1.8.0 torchvision==0.9.0 cudatoolkit=10.2 -c pytorch

Install timm==0.4.12:

pip install timm==0.4.12

安装其他环境:

pip install opencv-python==4.4.0.46 termcolor==1.1.0 yacs==0.1.8 pyyaml scipy

Install fused window process for acceleration, activated by passing --fused_window_process in the running script

cd kernels/window_process
python setup.py install #--user

2.数据集结构

$ tree data
imagenet
├── train
│   ├── class1
│   │   ├── img1.jpeg
│   │   ├── img2.jpeg
│   │   └── ...
│   ├── class2
│   │   ├── img3.jpeg
│   │   └── ...
│   └── ...
└── val
    ├── class1
    │   ├── img4.jpeg
    │   ├── img5.jpeg
    │   └── ...
    ├── class2
    │   ├── img6.jpeg
    │   └── ...
    └── ...

二、修改配置等文件

1.修改config.py

_C.DATA.DATA_PATH = ‘dataset’
数据集路径的根目录,我定义为dataset,将数据集放在dataset里

_C.DATA.DATASET = ‘imagenet’
数据集的类型,这里只有一种类型imagenet

_C.MODEL.NUM_CLASSES:模型的类别,默认是1000,按照数据集的类别数量修改。

_C.SAVE_FREQ = 10 ,每多少个epoch保存一次模型

_C.TRAIN.EPOCHS = 300
训练300轮

2.修改build.py

找到mixup部分,将nb_classes =1000改为nb_classes = config.MODEL.NUM_CLASSES
修改完像下面这样
在这里插入图片描述

3.修改utils.py

找到load_checkpoint函数
checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')后面插入

    if checkpoint['model']['head.weight'].shape[0] == 1000:
        checkpoint['model']['head.weight'] = torch.nn.Parameter(
            torch.nn.init.xavier_uniform(torch.empty(config.MODEL.NUM_CLASSES, 768)))
        checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(config.MODELNUM_CLASSES))

修改完如下所示
在这里插入图片描述

三、训练

1.Train

python -m torch.distributed.launch --nproc_per_node <num-of-gpus-to-use> --master_port 12345  main.py \ 
--cfg <config-file> --data-path <imagenet-path> [--batch-size <batch-size-per-gpu> --output <output-directory> --tag <job-tag>]

For example, to train Swin Transformer with 8 GPU on a single node for 300 epochs, run:

  • Swin-T:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_tiny_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
  • Swin-S:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_small_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 128 
  • Swin-B:
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345  main.py \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --data-path <imagenet-path> --batch-size 64 \
--accumulation-steps 2 [--use-checkpoint]

2.Evaluation

python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.py --eval \
--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>

nproc_per_node是GPU数量
config-file 是配置文件,在configs里

四、常见报错

1.TypeError: init() got an unexpected keyword argument ‘t_mul‘

删除Swin-Transformer/lr_scheduler.py的第24行‘t_mul=1.,’

评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值