timm 笔记:数据集

1 数据集channel数问题

1.1 torchvision的不足

        ImageNet数据由3通道RGB图像组成。因此,为了能够在大多数库中使用预先训练的权值,模型期望一个3通道的输入图像。

        比如对于resnet34,如果我们使用1个channel的输入的话:

import torch
import torchvision

m = torchvision.models.resnet34(pretrained=True)

x = torch.randn(1, 1, 224, 224)

try: m(x).shape
except Exception as e: print(e)
'''
Given groups=1, weight of size [64, 3, 7, 7], expected input[1, 1, 224, 224] to have 3 channels, but got 1 channels instead
'''

        是会报错的

        此时的一种方法是将1维的channel复制两次,成为三维的channel

import torch
import torchvision

m = torchvision.models.resnet34(pretrained=True)

x = torch.randn(1, 1, 224, 224)

x=torch.cat((x,x,x),1)# 新增了这一行

try: print(m(x).shape)
except Exception as e: print(e)
#torch.Size([1, 1000])

        然而,如果维度比3多的话,可能就没有办法删去某个维度,然后使用预训练模型。它们可以做的只是随机初始化权重,自己训练。

1.2 timm的解决方法

输入channel是1或者25都ok了

import timm

m = timm.create_model('resnet34', pretrained=True, in_chans=1)

x = torch.randn(1, 1, 224, 224)

m(x).shape

#torch.Size([1, 1000])
m = timm.create_model('resnet34', pretrained=True, in_chans=25)

# 25-channel image
x = torch.randn(1, 25, 224, 224)

m(x).shape
#torch.Size([1, 1000])

2 数据集Dataset

timm数据库中,有三种主要的数据集类:

  1. ImageDataset
  2. IterableImageDataset
  3. AugMixDataset

2.1 ImageDataset

  与torchvision.datasets.ImageFolder 类似,ImageDataset的作用是创建训练集和验证集

2.1.1 解析器 parser

        通过使用create_parser函数,我们可以自动设置解析器

        解析器找到所有root路径上的图片和目标

        root路径结构如下所示 

   解析器创建一个class_to_idx字典:

   

        同时有一个叫samples的元组列表:

        

        解析器是可以下标访问的, parser[index]将返回一个self.samples中标签是index的样本(比如parser[0],会返回一个('root/dog/xxx.png', 0)

 2.1.2 __getitem__(index: int) → Tuple[Any, Any]

一旦解析器创建完毕,那我们可以用以下方式获得图片和标签

img, target = self.parser[index]

然后将图像识别成PIL.Image,然后转换成RGB图像,还是读取成二进制,这取决于load_bytes语句

        如果图片没有target,那么我们将target设置为-1

2.1.3 使用场景

        ImageDataset也可以作为torchvision.datasets.ImageFolder的一个代替

        假设我们有imagenette2-320数据集,他的文件架构如下所示

数据集来源:

wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2-320.tgz

 每一个 n****都是一个文件夹,里面是属于这个类的JPEG文件

创建 ImageDataset:

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset[0]
#(<PIL.Image.Image image mode=RGB size=426x320 at 0x22E890BF5C8>, 0)

dataset.parser

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset.parser
#<timm.data.parsers.parser_image_folder.ParserImageFolder at 0x22e83cd0688>

 class_to_idx

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset.parser.class_to_idx
'''
{'n01440764': 0,
 'n02102040': 1,
 'n02979186': 2,
 'n03000684': 3,
 'n03028079': 4,
 'n03394916': 5,
 'n03417042': 6,
 'n03425413': 7,
 'n03445777': 8,
 'n03888257': 9}
'''

 paser的sample

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset.parser.samples[:5]
'''
[('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00000293.JPEG', 0),
 ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00002138.JPEG', 0),
 ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00003014.JPEG', 0),
 ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00006697.JPEG', 0),
 ('./imagenette2-320\\train\\n01440764\\ILSVRC2012_val_00007197.JPEG', 0)]
'''

可视化一张数据的图片

import matplotlib.pyplot as plt # plt 用于显示图片
import matplotlib.image as mpimg # mpimg 用于读取图片
import numpy as np

lena = mpimg.imread(dataset.parser.samples[0][0]) # 读取和代码处于同一目录下的 lena.png
# 此时 lena 就已经是一个 np.array 了,可以对它进行任意处理
lena.shape #(512, 512, 3)

plt.imshow(lena) # 显示图片
plt.axis('off') # 不显示坐标轴
plt.show()

 

2.2  IterableImageDataset

        和pytorch的 IterableDataset 类似,timm提供了 IterableImageDataset。

  和ImageDataset相似,IterableImageDataset首先创建一个解析器,他也基于根目录创建一组样本。

        和ImageDataset相似,解析器也返回一组图像,图像的target也是图像所在的文件夹名称

   ***但有一点需要注意,IterableImageDataset并没有__getitem__方法,因此他不可以用下标访问。dataset[0]会报错

2.2.1 __iter__

       从IterableImageDataset的解析器中得到图片和对应的标签

from timm.data import IterableImageDataset
from timm.data.parsers.parser_image_folder import ParserImageFolder
from timm.data.transforms_factory import create_transform 

root = './imagenette2-320/'
parser = ParserImageFolder(root)
iterable_dataset = IterableImageDataset(root=root, parser=parser)
parser[0]
# (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0)
next(iter(iterable_dataset))
# (<_io.BufferedReader name='./imagenette2-320/train\\n01440764\\ILSVRC2012_val_00000293.JPEG'>,0)

2.3 AugmixDataset

        augmix 是一种数据增强的方法

class AugmixDataset(
    dataset: ImageDataset, 
    num_splits: int = 2)

        最后的返回结果是 original data 和num_splits-1 轮的增强数据(每一轮增强数据都是原始数据的基础上获得的)

2.3.1   __getitem__(index: int) -> Tuple[Any, Any]

2.3.2 使用方法

这个需要GPU,所以我在服务器上跑的

>>> from timm.data import ImageDataset, IterableImageDataset, AugMixDataset, create_loader
>>>
>>> dataset = ImageDataset('./imagenette2-320/')
>>> dataset = AugMixDataset(dataset, num_splits=2)
>>> loader_train = create_loader(
...     dataset,
...     input_size=(3, 224, 224),
...     batch_size=8,
...     is_training=True,
...     scale=[0.08, 1.],
...     ratio=[0.75, 1.33],
...     num_aug_splits=2
... )
>>> next(iter(loader_train))[0].shape

torch.Size([16, 3, 224, 224])

注意看这里,我们的batch_size是8,返回的是16维,因为original是8,这里augmix又是8维

3 DataLoader

timm的 Dataloader比`torch.utils.data.DataLoader`快,且略有不同

创建timm的dataloader的最基本的方法就是调用timm.data.loader中的create_loader。它需要一个dataset对象,一个input_size和一个batch_size

3.1 创建dataset

 创建 ImageDataset:

from timm.data.dataset import ImageDataset

dataset = ImageDataset('./imagenette2-320')
dataset[0]
#(<PIL.Image.Image image mode=RGB size=426x320 at 0x22E890BF5C8>, 0)

3.2 创建DataLoader

from timm.data.loader import create_loader

try:
    # only works if gpu present on machine
    train_loader = create_loader(dataset, (3, 224, 224), 4)
except:
    train_loader = create_loader(dataset, (3, 224, 224), 4, use_prefetcher=False)

那么,这里为什么要用异常处理语句呢? 

3.2.1 Prefetch loader

        timm 有一个类PrefetchLoader。我们默认用这个DataLoader来创建我们的DataLoader。但是它只工作在GPU上。

        我本地的train_loader:

<torch.utils.data.dataloader.DataLoader at 0x22e834c3548>

        服务器(有GPU)的train_loader:

<timm.data.loader.PrefetchLoader object at 0x7f65acf9cef0>
 

  • 5
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

UQI-LIUWJ

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值