Suppose the only image-classification data you have is laid out in the following directory tree:
./车辆分类数据集
├── car
│   └── *.jpg
├── bus
│   └── *.jpg
└── truck
    └── *.jpg
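If this is all the data you have, you need to build a training ImageFolder and a validation ImageFolder (and possibly a test one) from the root directory (i.e. ./车辆分类数据集) before you can use a DataLoader. You do not need to split the dataset by hand or with os-style file-moving code. Subclassing ImageFolder and overriding its constructor is enough.
Here is the code.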
# -*- coding: utf-8 -*-
"""
@Time : 2022/8/14 10:09
@Auth : Fanteng Meng
@File :imgfolder.py
@IDE :PyCharm
@Github : https://github.com/FT115
"""
import torch
import random
import torchvision.transforms as transforms
import torchvision.datasets as datasets

normalize = transforms.Normalize(mean=[.5, .5, .5],
                                 std=[.5, .5, .5])

# Training pipeline: resize, convert to tensor, random flip, normalize.
train_transform = transforms.Compose([])
train_transform.transforms.append(transforms.Resize((224, 224)))
train_transform.transforms.append(transforms.ToTensor())
train_transform.transforms.append(transforms.RandomHorizontalFlip(p=0.8))
train_transform.transforms.append(normalize)

# Validation pipeline: same as training but without augmentation.
val_transform = transforms.Compose([])
val_transform.transforms.append(transforms.Resize((224, 224)))
val_transform.transforms.append(transforms.ToTensor())
val_transform.transforms.append(normalize)


class CustomImageFolder(datasets.ImageFolder):
    def __init__(self, root, transform, mode, train_ratio):
        super(CustomImageFolder, self).__init__(root, transform)
        assert mode in ['train', 'val']
        # Same seed for every instance, so 'train' and 'val' slice the same permutation.
        random.seed(0)
        random.shuffle(self.samples)
        if mode == 'train':
            # The first train_ratio fraction of the shuffled samples.
            self.samples = self.samples[:int(train_ratio*len(self))]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples
        elif mode == 'val':
            # The remaining samples.
            self.samples = self.samples[int(train_ratio*len(self)):]
            self.targets = [s[1] for s in self.samples]
            self.imgs = self.samples

    def __getitem__(self, index):
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target

    def __len__(self) -> int:
        return len(self.samples)


batch_size = 6
data_path = './实验三数据集/车辆分类数据集'
train_dataset = CustomImageFolder(data_path, train_transform, 'train', 0.7)
val_dataset = CustomImageFolder(data_path, val_transform, 'val', 0.7)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)
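The rest of this post explains selected parts of the code.
1. Shuffle all samples before splitting
The original self.samples lists the images class by class. If you slice it without shuffling, the first train_ratio fraction is severely class-imbalanced, and the validation slice may contain only the last class. Shuffling the (path, label) pairs makes each split roughly follow the overall class distribution, although it does not guarantee exact balance. Fixing the random seed makes the shuffle produce the same order every time. This matters here because the training and validation datasets are built as two separate instances: with the same seed they slice the same permutation, so the two splits do not overlap. The corresponding code: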
random.seed(0)
random.shuffle(self.samples)
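To see the effect without any images on disk, here is a minimal sketch in plain Python. The `samples` list is a hypothetical stand-in for `ImageFolder.samples`, and the `split` helper is my own illustration, not part of torchvision. It also uses a local `random.Random(seed)` instance rather than `random.seed(0)`, so the global random state is left alone.

```python
import random
from collections import Counter

# Hypothetical stand-in for ImageFolder.samples: (path, class_index) pairs,
# listed class by class, 10 images per class.
samples = [(f"{name}/{i}.jpg", label)
           for label, name in enumerate(["bus", "car", "truck"])
           for i in range(10)]


def split(samples, train_ratio, shuffle, seed=0):
    """Return (train, val) slices, optionally after a seeded shuffle."""
    samples = list(samples)  # copy so the caller's list is untouched
    if shuffle:
        random.Random(seed).shuffle(samples)
    cut = int(train_ratio * len(samples))
    return samples[:cut], samples[cut:]


train_plain, val_plain = split(samples, 0.7, shuffle=False)
train_shuf, val_shuf = split(samples, 0.7, shuffle=True)

# Without shuffling, the validation slice holds only the last class.
print(Counter(label for _, label in val_plain))
# With shuffling, the per-class counts depend on the seed.
print(Counter(label for _, label in val_shuf))
```

Calling `split` twice with the same seed returns the same two slices, which is the property the two separate dataset instances rely on.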
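2. The added mode and train_ratio parameters
mode selects whether the instance is the training set or the validation set (if you want a test set as well, you can add one yourself). train_ratio is the fraction of the data used for training: the first train_ratio of the shuffled samples becomes the training set and the rest becomes the validation set. All that is needed is to slice the shuffled samples accordingly. self.targets is rebuilt the same way the official ImageFolder builds it, namely as the element at index 1 of every sample. self.imgs is kept equal to self.samples, as in the official implementation, so it is simply reassigned here. The corresponding code: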
if mode == 'train':
    self.samples = self.samples[:int(train_ratio*len(self))]
    self.targets = [s[1] for s in self.samples]
    self.imgs = self.samples
elif mode == 'val':
    self.samples = self.samples[int(train_ratio*len(self)):]
    self.targets = [s[1] for s in self.samples]
    self.imgs = self.samples
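If you do want a test set, one possible way is to cut the shuffled list at two points instead of one. The sketch below is only an illustration in plain Python: the helper name `select_split`, the `val_ratio` parameter, and the default ratios are my assumptions and do not appear in the original class.

```python
import random


def select_split(samples, mode, train_ratio=0.7, val_ratio=0.15, seed=0):
    """Return the slice of `samples` for 'train', 'val', or 'test'.

    Hypothetical helper: the same seeded shuffle is applied for every mode,
    so the three slices never overlap and together cover all samples.
    """
    assert mode in ("train", "val", "test")
    assert train_ratio > 0 and val_ratio >= 0 and train_ratio + val_ratio <= 1
    shuffled = list(samples)  # copy so the caller's list is untouched
    random.Random(seed).shuffle(shuffled)
    n = len(shuffled)
    train_end = int(train_ratio * n)
    val_end = train_end + int(val_ratio * n)
    if mode == "train":
        return shuffled[:train_end]
    if mode == "val":
        return shuffled[train_end:val_end]
    return shuffled[val_end:]  # 'test' takes whatever remains


# Toy (path, class_index) pairs standing in for ImageFolder.samples.
toy = [(f"img_{i}.jpg", i % 3) for i in range(20)]
parts = {m: select_split(toy, m) for m in ("train", "val", "test")}
print({m: len(p) for m, p in parts.items()})
```

Inside the constructor, the returned list would be assigned to self.samples, after which self.targets and self.imgs are rebuilt exactly as in the two branches above.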
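def __getitem__(self, index) and def __len__(self) are unchanged and were copied over as they are. Because they are unchanged, and ImageFolder already provides them through inheritance, you can also leave them out.
I wrote this down to record what I learned, and I hope it helps you too~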