windows11下运行swin-transformer算法

一、背景

我们希望使用swin-transformer算法实现物体的分类。 

swin-transformer的github地址为:https://github.com/microsoft/Swin-Transformer

本文参考:Swin Transformer实战:使用 Swin Transformer实现图像分类-阿里云开发者社区

二、环境配置

(1)配置要求

windows版本:windows11

pytorch版本:1.7.1

python版本:3.7.3(至少要大于3.6.2,因为pytorch1.7.1的python最低要求是3.6.2)

cuda版本:11.0(pytorch1.7.1在windows11下使用,最少需要cuda11.0)

以上配置为我试验swin-transformer运行的相对比较低的配置要求。

(2)安装方式

torch1.7.1的安装命令:pip install torch==1.7.1 -f https://download.pytorch.org/whl/torch_stable.html

torchvision的安装命令:pip install torchvision==0.8.2 -f https://download.pytorch.org/whl/torch_stable.html

cuda中官网下载toolkit即可,windows下cuda11.0.X版本可选择windows10对应的版本。

三、训练集构造

swin-transformer默认读取imagenet格式的数据集。

数据集的目录结构如下:

四、修改源码

1、修改config.py文件

_C.DATA.DATA_PATH = 'D:\\temp\\pic_ai\\swin_transformer_test'

_C.MODEL.NUM_CLASSES = 2

_C.DATA.NUM_WORKERS = 0

_C.DATA.PIN_MEMORY = False   

2、修改build.py文件

将nb_classes =1000改为nb_classes = config.MODEL.NUM_CLASSES,如下所示:

将部分_pil_interp修改为str_to_pil_interp,如下图所示:

 3、修改utils.py文件

由于类别默认是1000,所以加载模型的时候会出现类别对不上的问题,所以需要修改load_checkpoint方法。在加载预训练模型之前增加修改预训练模型的方法:

if checkpoint['model']['head.weight'].shape[0] == 1000:
    checkpoint['model']['head.weight'] = torch.nn.Parameter(
        torch.nn.init.xavier_uniform(torch.empty(config.MODEL.NUM_CLASSES, 768)))
    checkpoint['model']['head.bias'] = torch.nn.Parameter(torch.randn(config.MODELNUM_CLASSES))
msg = model.load_state_dict(checkpoint['model'], strict=False)

 

 4、修改main.py文件

(1)将如下代码注释:

(2)将torch.distributed.init_process_group修改为:

torch.distributed.init_process_group('gloo', init_method='file://tmp/somefile', rank=0, world_size=1)

该函数只有在pytorch1.7.1以上才支持。

5、修改lr_scheduler.py文件

将如下代码注释掉

五、运行训练命令:

python.exe D:/workspace/transformer/Swin-Transformer/main.py --cfg configs/swin/swin_tiny_patch4_window7_224.yaml --local_rank 0 --batch-size 2

执行后显示如下:

  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值