Swin Transformer实战: timm使用、Mixup、Cutout和评分一网打尽,图像分类任务

本文介绍了如何使用Swin Transformer进行图像分类任务,包括数据预处理、Mixup和Cutout增强技术,以及训练和验证过程。通过计算mean和std值进行归一化,调整数据集格式,利用timm库和Mixup、Cutout优化模型,最终在训练和验证过程中应用SoftTargetCrossEntropy和CrossEntropyLoss损失函数,实现模型的训练和性能评估。
摘要由CSDN通过智能技术生成

=====================================================================

为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:

from torchvision.datasets import ImageFolder

import torch

from torchvision import transforms

def get_mean_and_std(train_data):

train_loader = torch.utils.data.DataLoader(

train_data, batch_size=1, shuffle=False, num_workers=0,

pin_memory=True)

mean = torch.zeros(3)

std = torch.zeros(3)

for X, _ in train_loader:

for d in range(3):

mean[d] += X[:, d, :, :].mean()

std[d] += X[:, d, :, :].std()

mean.div_(len(train_data))

std.div_(len(train_data))

return list(mean.numpy()), list(std.numpy())

if name == ‘main’:

train_dataset = ImageFolder(root=r’data1’, transform=transforms.ToTensor())

print(get_mean_and_std(train_dataset))

数据集结构:

image-20220221153058619

运行结果:

([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])

把这个结果记录下来,后面要用!

生成数据集

================================================================

我们整理还的图像分类的数据集结构是这样的

data

├─Black-grass

├─Charlock

├─Cleavers

├─Common Chickweed

├─Common wheat

├─Fat Hen

├─Loose Silky-bent

├─Maize

├─Scentless Mayweed

├─Shepherds Purse

├─Small-flowered Cranesbill

└─Sugar beet

pytorch和keras默认加载方式是ImageNet数据集格式,格式是

├─data

│ ├─val

│ │ ├─Black-grass

│ │ ├─Charlock

│ │ ├─Cleavers

│ │ ├─Common Chickweed

│ │ ├─Common wheat

│ │ ├─Fat Hen

│ │ ├─Loose Silky-bent

│ │ ├─Maize

│ │ ├─Scentless Mayweed

│ │ ├─Shepherds Purse

│ │ ├─Small-flowered Cranesbill

│ │ └─Sugar beet

│ └─train

│ ├─Black-grass

│ ├─Charlock

│ ├─Cleavers

│ ├─Common Chickweed

│ ├─Common wheat

│ ├─Fat Hen

│ ├─Loose Silky-bent

│ ├─Maize

│ ├─Scentless Mayweed

│ ├─Shepherds Purse

│ ├─Small-flowered Cranesbill

│ └─Sugar beet

新增格式转化脚本makedata.py,插入代码:

import glob

import os

import shutil

image_list=glob.glob(‘data1//.png’)

print(image_list)

file_dir=‘data’

if os.path.exists(file_dir):

pr

  • 18
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值