Swin Transformer实战: timm使用、Mixup,附大厂真题面经

│ │ ├─Black-grass

│ │ ├─Charlock

│ │ ├─Cleavers

│ │ ├─Common Chickweed

│ │ ├─Common wheat

│ │ ├─Fat Hen

│ │ ├─Loose Silky-bent

│ │ ├─Maize

│ │ ├─Scentless Mayweed

│ │ ├─Shepherds Purse

│ │ ├─Small-flowered Cranesbill

│ │ └─Sugar beet

│ └─train

│ ├─Black-grass

│ ├─Charlock

│ ├─Cleavers

│ ├─Common Chickweed

│ ├─Common wheat

│ ├─Fat Hen

│ ├─Loose Silky-bent

│ ├─Maize

│ ├─Scentless Mayweed

│ ├─Shepherds Purse

│ ├─Small-flowered Cranesbill

│ └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob

import os

import shutil

image_list=glob.glob(‘data1//.png’)

print(image_list)

file_dir=‘data’

if os.path.exists(file_dir):

print(‘true’)

#os.rmdir(file_dir)

shutil.rmtree(file_dir)#删除再建立

os.makedirs(file_dir)

else:

os.makedirs(file_dir)

from sklearn.model_selection import train_test_split

trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)

train_dir=‘train’

val_dir=‘val’

train_root=os.path.join(file_dir,train_dir)

val_root=os.path.join(file_dir,val_dir)

for file in trainval_files:

file_class=file.replace(“\”,“/”).split(‘/’)[-2]

file_name=file.replace(“\”,“/”).split(‘/’)[-1]

file_class=os.path.join(train_root,file_class)

if not os.path.isdir(file_class):

os.makedirs(file_class)

shutil.copy(file, file_class + ‘/’ + file_name)

for file in val_files:

file_class=file.replace(“\”,“/”).split(‘/’)[-2]

file_name=file.replace(“\”,“/”).split(‘/’)[-1]

file_class=os.path.join(val_root,file_class)

if not os.path.isdir(file_class):

os.makedirs(file_class)

shutil.copy(file, file_class + ‘/’ + file_name)

训练

=============================================================

完成上面的步骤后,就开始train脚本的编写,新建train.py.

导入项目使用的库


import torch

import torch.nn as nn

import torch.nn.parallel

import torch.optim as optim

import torch.utils.data

import torch.utils.data.distributed

import torchvision.datasets as datasets

import torchvision.transforms as transforms

from sklearn.metrics import classification_report

from timm.data.mixup import Mixup

from timm.loss import SoftTargetCrossEntropy

from timm.models import swin_small_patch4_window7_224

from torchtoolbox.transform import Cutout

设置全局参数


设置学习率、BatchSize、epoch等参数,判断环境中是否存在GPU,如果没有则使用CPU。建议使用GPU,CPU太慢了。

设置全局参数

model_lr = 1e-4

BATCH_SIZE = 4

EPOCHS = 1000

DEVICE = torch.device(‘cuda:0’ if torch.cuda.is_available() else ‘cpu’)

图像预处理与增强


数据处理比较简单,加入了Cutout、做了Resize和归一化,定义Mixup函数。

数据预处理7

transform = transforms.Compose([

transforms.Resize((224, 224)),

Cutout(),

transforms.ToTensor(),

transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])

])

transform_test = transforms.Compose([

transforms.Resize((224, 224)),

transforms.ToTensor(),

transforms.Normalize(mean=[0.51819474, 0.5250407, 0.4945761], std=[0.24228974, 0.24347611, 0.2530049])

])

mixup_fn = Mixup(

mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,

prob=0.1, switch_prob=0.5, mode=‘batch’,

label_smoothing=0.1, num_classes=12)

读取数据


使用pytorch默认读取数据的方式,然后将dataset_train.class_to_idx打印出来,预测的时候要用到。

读取数据

dataset_train = datasets.ImageFolder(‘data/train’, transform=transform)

dataset_test = datasets.ImageFolder(“data/val”, transform=transform_test)

print(dataset_train.class_to_idx)

导入数据

train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)

class_to_idx的结果:

{‘Black-grass’: 0, ‘Charlock’: 1, ‘Cleavers’: 2, ‘Common Chickweed’: 3, ‘Common wheat’: 4, ‘Fat Hen’: 5, ‘Loose Silky-bent’: 6, ‘Maize’: 7, ‘Scentless Mayweed’: 8, ‘Shepherds Purse’: 9, ‘Small-flowered Cranesbill’: 10, ‘Sugar beet’: 11}

设置模型


  • 设置loss函数,train的loss为:SoftTargetCrossEntropy,val的loss:nn.CrossEntropyLoss()。

  • 设置模型为swin_small_patch4_window7_224,预训练设置为true,num_classes设置为12。

  • 检测可用显卡的数量,如果大于1,则要用torch.nn.DataParallel加载模型,开启多卡训练。

  • 优化器设置为adam。

  • 学习率调整策略选择为余弦退火。

实例化模型并且移动到GPU

criterion_train = SoftTargetCrossEntropy()

criterion_val = torch.nn.CrossEntropyLoss()

model_ft = swin_small_patch4_window7_224(pretrained=True)

print(model_ft)

num_ftrs = model_ft.head.in_features

model_ft.head = nn.Linear(num_ftrs, 12)

model_ft.to(DEVICE)

print(model_ft)

if torch.cuda.device_count() > 1:

print(“Let’s use”, torch.cuda.device_count(), “GPUs!”)

model_ft = torch.nn.DataParallel(model_ft)

print(model_ft)

选择简单暴力的Adam优化器,学习率调低

optimizer = optim.Adam(model_ft.parameters(), lr=model_lr)

cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-9)

定义训练和验证函数


定义训练函数和验证函数,在一个epoch完成后,使用classification_report计算详细的得分情况。

定义训练过程

def train(model, device, train_loader, optimizer, epoch):

model.train()

sum_loss = 0

total_num = len(train_loader.dataset)

print(total_num, len(train_loader))

for batch_idx, (data, target) in enumerate(train_loader):

data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)

samples, targets = mixup_fn(data, target)

optimizer.zero_grad()

output = model(data)

loss = criterion_train(output, targets)

loss.backward()

optimizer.step()

lr = optimizer.state_dict()[‘param_groups’][0][‘lr’]

print_loss = loss.data.item()

sum_loss += print_loss

if (batch_idx + 1) % 10 == 0:

print(‘Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}’.format(

epoch, (batch_idx + 1) * len(data), len(train_loader.dataset),

    • (batch_idx + 1) / len(train_loader), loss.item(), lr))

ave_loss = sum_loss / len(train_loader)

print(‘epoch:{},loss:{}’.format(epoch, ave_loss))

ACC = 0

验证过程

def val(model, device, test_loader):

global ACC

model.eval()

test_loss = 0

correct = 0

total_num = len(test_loader.dataset)

print(total_num, len(test_loader))

val_list = []

pred_list = []

with torch.no_grad():

for data, target in test_loader:

for t in target:

val_list.append(t.data.item())

data, target = data.to(device), target.to(device)

output = model(data)

loss = criterion_val(output, target)

_, pred = torch.max(output.data, 1)

for p in pred:

pred_list.append(p.data.item())

correct += torch.sum(pred == target)

print_loss = loss.data.item()

test_loss += print_loss

correct = correct.data.item()

acc = correct / total_num

avgloss = test_loss / len(test_loader)

print(‘\nVal set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n’.format(

avgloss, correct, len(test_loader.dataset), 100 * acc))

if acc > ACC:

torch.save(model_ft, ‘model_’ + str(epoch) + ‘_’ + str(round(acc, 3)) + ‘.pth’)

ACC = acc

return val_list, pred_list

训练

for epoch in range(1, EPOCHS + 1):

train(model_ft, DEVICE, train_loader, optimizer, epoch)

cosine_schedule.step()

val_list, pred_list = val(model_ft, DEVICE, test_loader)

print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))

运行结果:

image-20220222161855186

测试

=============================================================

我介绍两种常用的测试方式,第一种是通用的,通过自己手动加载数据集然后做预测,具体操作如下:

测试集存放的目录如下图:

image-20220205214047841

第一步 定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

第二步 定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

第三步 加载model,并将模型放在DEVICE里,

第四步 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。

import torch.utils.data.distributed

一、Python所有方向的学习路线

Python所有方向路线就是把Python常用的技术点做整理,形成各个领域的知识点汇总,它的用处就在于,你可以按照上面的知识点去找对应的学习资源,保证自己学得较为全面。

二、学习软件

工欲善其事必先利其器。学习Python常用的开发软件都在这里了,给大家节省了很多时间。

三、入门学习视频

我们在看视频学习的时候,不能光动眼动脑不动手,比较科学的学习方法是在理解之后运用它们,这时候练手项目就很适合了。

小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。

深知大多数初中级Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!

因此收集整理了一份《2024年Python爬虫全套学习资料》送给大家,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。

由于文件比较大,这里只是将部分目录截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频

如果你觉得这些内容对你有帮助,可以添加下面V无偿领取!(备注:python)
img

项目就很适合了。

小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。

深知大多数初中级Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!

因此收集整理了一份《2024年Python爬虫全套学习资料》送给大家,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。

由于文件比较大,这里只是将部分目录截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频

如果你觉得这些内容对你有帮助,可以添加下面V无偿领取!(备注:python)
[外链图片转存中…(img-82OSpfUx-1711203377652)]

  • 3
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值