如何训练一个ViT模型-基于timm(2)

省流:训练(微调)一个自己的ViT模型

如果不想用预训练好的权重,想微调模型,或者干脆重新训练一个模型,该怎么办呢

脚本

官方提供了一些示例脚本,可以在github中下载

也可参考HuggingFace的文档,然后按照实列来运行

Scripts (huggingface.co)

比如

./distributed_train.sh 2 /imagenet -b 64 --model resnet50 --sched cosine --epochs 200 --lr 0.05 --amp --remode pixel --reprob 0.6 --aug-splits 3 --aa rand-m9-mstd0.5-inc1 --resplit --split-bn --jsd --dist-bn reduce

resnet50在imagenet上训练200epochs 余弦学习率 初始lr0.05 使用2张GPU分布式训练

自己写train

但是脚本上一大堆参数对于新手不是很友好,如果想快速上手体验一下的话,请重点阅读下一部分

接下来以最简单的代码,写一个微调ViT在Cifar10上训练的代码,来做示范

一个训练程序通常由以下几部分组成

1.导入包

2.解析参数

3.定义模型

4.准备数据

5.定义loss函数和optimzer

6.train loop

导入包

因为主要基于timm来展示,就只导入必要的torch和timm。tqdm是个可视化进度条的包

import torch

import timm

import tqdm

解析参数

对于新手,从命令行解析参数往往是一大难关,各种规则看不明白

那么我建议可以试试这样手动定义参数

#手动定义参数类,专门存储训练的参数,

class cfg(object):

pass

然后在后面随便写入自己想定义的参数,比如

args.epochs = 10#定义epochs

args.batch_size = 32 #定义batchsize

定义模型

模型导入预训练ViT,改变最后一层分类层,使之符合cifar10的10个类别

model = timm.create_model(args.model, pretrained=True,num_classes=10)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device) #可不可以用gpu,不行的话就与、用cpu

准备数据

数据准备分三步,dataset,transforms,dataloader

限于篇幅这边先不展开,感兴趣的同学可以先网上搜一下了解了解,之后也会讲这方面内容

dataset_train = create_dataset(args.dataset,root=args.data_dir,split='train',download=True)#导入dataset
​
train_transform = transforms.Compose([transforms.Resize(224),transforms.RandomHorizontalFlip(), transforms.RandomCrop(32, padding=4), transforms.ToTensor(),transforms.Normalize(mean, std)])
​
dataset_train.transform = train_transform#定义transforms
​
loader_train = create_loader(dataset_train, input_size=(3,224,224),batch_size=args.batch_size,use_prefetcher=args.prefetcher)
​
​

定义loss函数和optimzer

criterion = torch.nn.CrossEntropyLoss()

optimizer = torch.optim.SGD(model.head.parameters(),args.learning_rate)

优化器传入ViT head的参数,只训练后面的分类层

train loop

# 两层循环:

for epoch in num_epochs:

for input in batch_iters:

# 5句函数

outputs = model(input)

model.zero_grad()

loss = criterion(outputs, labels)

loss.backward()

optimizer.step()

我在自己电脑上简单跑了一下,大概2分钟一个epoch,还是挺快的,代码已经上传大家可以试一下

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值