创建自己的数据处理类(图像分类——ImageFolder)

代码:

ImageFolder 使用示例:

'''训练模型'''
# Training script (excerpt): builds the train/validation ImageFolder datasets
# and DataLoaders. Steps 2-4 (model, loss/optimizer, training loop) are only
# placeholders at the end of this block.
import os
import warnings

# NOTE(review): presumably a workaround for the duplicate OpenMP runtime crash
# ("OMP: Error #15"); it hides the symptom rather than removing the duplicate
# library. It is set before torch is imported below — confirm it is still needed.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import time
from thop import profile  # counts a model's parameters and FLOPs
import matplotlib.pyplot as plt  # plotting
import torch
from torch.optim import lr_scheduler  # learning-rate schedulers (not optimizers)
from torch.utils.data import DataLoader
# Dataset / image-transform utilities
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm  # progress bars
from net import resnet50, resnet101, MyAlexNet  # project-local model definitions

start = time.time()  # wall-clock start time of the script

# Make matplotlib render Chinese text (SimHei font) and the minus sign correctly.
plt.rcParams["font.sans-serif"] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False

'''1、数据处理'''
# 1. Data processing
ROOT_train = r'Z:\meter_classify\datasets\xldlb'
ROOT_test = r"Z:\meter_classify\datasets\xldlb"

# If validation reads the same folder as training, every validation image is
# also a training image and validation metrics overstate generalization.
# Keep running (the paths may be intentional for a smoke test) but say so loudly.
if os.path.normcase(os.path.abspath(ROOT_train)) == os.path.normcase(os.path.abspath(ROOT_test)):
    warnings.warn(
        "ROOT_train and ROOT_test point to the same directory; "
        "validation results will be measured on training data."
    )

# Training-time preprocessing (with augmentation).
train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),  # resize every image to a fixed 224x224
    transforms.RandomVerticalFlip(),  # augmentation: random vertical *flip* (not a rotation)
    transforms.ToTensor(),  # convert the image to a tensor
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # per-channel mean 0.5, std 0.5
])
# Validation preprocessing: same resize/tensor/normalize steps, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# One sub-directory per class under each root; labels are class indices.
train_dataset = ImageFolder(ROOT_train, transform=train_transforms)
val_dataset = ImageFolder(ROOT_test, transform=val_transforms)

# Batch the data (batch size 32); shuffle only the training set so the
# validation order stays fixed.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False)


'''2、导入模型'''
'''3、定义损失和优化'''
'''4、训练、保存模型'''
# 2. Load model  3. Define loss and optimizer  4. Train and save the model
# (not included in this excerpt).

ImageFolder 类:

class ImageFolder(DatasetFolder):
    """Image-classification dataset with one sub-directory per class.

    Default expected layout::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    All loading logic lives in :class:`~torchvision.datasets.DatasetFolder`.
    This subclass only chooses the file filter and exposes ``imgs``. The same
    methods as on ``DatasetFolder`` can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): Applied to each loaded image (e.g.
            ``transforms.RandomCrop``); returns the transformed image.
        target_transform (callable, optional): Applied to each target.
        loader (callable, optional): Loads an image given its path.
        is_valid_file (callable, optional): Takes a file path and returns
            whether the file should be used (e.g. to skip corrupt files).
            When supplied, the image-extension filter is not used.

    Attributes:
        classes (list): Class names sorted alphabetically.
        class_to_idx (dict): Maps class_name -> class_index.
        imgs (list): ``(image_path, class_index)`` tuples; the same list
            object as ``samples``.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Only one filter is forwarded: the image-extension whitelist is
        # dropped whenever a custom validity predicate is supplied.
        extensions = None if is_valid_file is not None else IMG_EXTENSIONS
        super().__init__(
            root,
            loader=loader,
            extensions=extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )
        # Expose the sample list under the `imgs` name as well (same object).
        self.imgs = self.samples

ImageFolder 的父类 DatasetFolder(重要):

class DatasetFolder(VisionDataset):
    """Generic folder-based dataset of ``(sample, class_index)`` pairs.

    Class discovery is done by ``find_classes`` and the sample index by
    ``make_dataset``; subclasses may override either. Samples are read
    lazily: ``__getitem__`` calls ``loader`` on the stored path, then applies
    the optional transforms.

    Args:
        root: Root directory path.
        loader: Callable that loads a sample given its path.
        extensions: Allowed file extensions, or None.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.
        is_valid_file: Optional predicate on a file path deciding inclusion.

    Attributes:
        classes (list): Class names returned by ``find_classes``.
        class_to_idx (dict): Maps class name -> class index.
        samples (list): ``(sample_path, class_index)`` tuples.
        targets (list): The class index of each entry in ``samples``.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Discover the classes first, then index every sample against them.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        self.loader = loader  # e.g. a PIL- or accimage-based image loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Return the ``(sample_path, class_index)`` tuples found under *directory*.

        Delegates to the module-level ``make_dataset`` helper.

        Raises:
            ValueError: if *class_to_idx* is None.
        """
        if class_to_idx is None:
            # Require an explicit mapping. NOTE(review): presumably the
            # module-level helper would otherwise derive classes on its own,
            # bypassing a subclass's find_classes() override — confirm
            # against torchvision.
            raise ValueError("The class_to_idx parameter cannot be None.")
        return make_dataset(directory, class_to_idx, extensions=extensions, is_valid_file=is_valid_file)

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return ``(classes, class_to_idx)`` for *directory*; override to customize."""
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load, transform and return one sample.

        Args:
            index (int): Index into ``samples``.

        Returns:
            tuple: ``(sample, target)``, where *sample* is the loaded (and
            optionally transformed) sample and *target* is its class index
            (optionally transformed by ``target_transform``).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return sample, target  # (img, label)

    def __len__(self) -> int:
        """Return the number of indexed samples."""
        return len(self.samples)
       


分析说明:

1. ImageFolder 类负责读取一张图片,对其进行预处理(transform),然后返回预处理后的图片及对应的标签(该图片所属类别的索引);
2. ImageFolder 类继承自 DatasetFolder 类,实际由 DatasetFolder 类的 __getitem__(self, index: int) 方法返回预处理后的一张图片和 target(该图片所属类别的索引);
3. DataLoader 按索引反复调用 DatasetFolder 类的 __getitem__(self, index: int) 方法,逐个获得预处理后的图片和 target(该图片所属类别的索引),再按照 batch_size 将它们拼接成一个批次(batch);


  • 10
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值