代码:
ImageFolder 使用示例:
'''训练模型'''
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import time
from thop import profile # 用于计算模型的参数数量和FLOPs
import matplotlib.pyplot as plt # 画图
import torch
from torch.optim import lr_scheduler # 优化器
from torch.utils.data import DataLoader
# 处理数据集的库
from torchvision import transforms
from torchvision.datasets import ImageFolder
from tqdm import tqdm # 进度条
from net import resnet50,resnet101,MyAlexNet # 导入写好的网络模型
# Wall-clock timestamp taken when the script starts.
start = time.time()

# Per the original note, these two matplotlib settings fix the display of
# Chinese text in plots (SimHei font, non-Unicode minus sign).
plt.rcParams['axes.unicode_minus'] = False
plt.rcParams["font.sans-serif"] = ['SimHei']

# Section 1: data processing.
'''1、数据处理'''
# NOTE(review): the training and validation roots point at the same directory,
# so validation runs on the training images -- confirm this is intended.
ROOT_train = r'Z:\meter_classify\datasets\xldlb'
ROOT_test = r"Z:\meter_classify\datasets\xldlb"


def _build_pipeline(*augmentations):
    """Return the shared preprocessing pipeline with optional augmentations.

    Every image is resized to 224x224, then passed through ``augmentations``
    (in the given order), converted to a tensor and finally normalized with
    0.5 / 0.5 on each of the three channels.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ])


# Training images additionally get a random vertical flip as augmentation;
# validation images go through the same steps without any augmentation.
train_transforms = _build_pipeline(transforms.RandomVerticalFlip())
val_transforms = _build_pipeline()

# Build each dataset and wrap it in a loader that yields batches of 32;
# only the training loader shuffles its samples.
train_dataset = ImageFolder(root=ROOT_train, transform=train_transforms)
train_dataloader = DataLoader(dataset=train_dataset, batch_size=32, shuffle=True)

val_dataset = ImageFolder(root=ROOT_test, transform=val_transforms)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=32, shuffle=False)

# Sections 2-4 (model import, loss/optimizer, training and saving) are
# placeholders with no code in this excerpt.
'''2、导入模型'''
'''3、定义损失和优化'''
'''4、训练、保存模型'''
ImageFolder 类:
class ImageFolder(DatasetFolder):
    """Generic image dataset whose class labels come from sub-directory names.

    By default the images are expected to be arranged like this::

        root/dog/xxx.png
        root/dog/xxy.png
        root/dog/[...]/xxz.png

        root/cat/123.png
        root/cat/nsdf3.png
        root/cat/[...]/asd932_.png

    The class derives from :class:`~torchvision.datasets.DatasetFolder`, so
    the same methods can be overridden to customize the dataset.

    Args:
        root (string): Root directory path.
        transform (callable, optional): Function/transform that takes a PIL
            image and returns a transformed version, e.g.
            ``transforms.RandomCrop``.
        target_transform (callable, optional): Function/transform applied to
            the target.
        loader (callable, optional): Function that loads an image given its
            path.
        is_valid_file (callable, optional): Function that takes the path of
            an image file and reports whether it is a valid file (used to
            check for corrupt files).

    Attributes:
        classes (list): List of the class names sorted alphabetically.
        class_to_idx (dict): Dict with items (class_name, class_index).
        imgs (list): List of (image path, class_index) tuples.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        # Only one of the two file filters is forwarded to the parent: the
        # image-extension tuple when no predicate is supplied, otherwise None.
        if is_valid_file is None:
            allowed_extensions = IMG_EXTENSIONS
        else:
            allowed_extensions = None

        super().__init__(
            root,
            loader=loader,
            extensions=allowed_extensions,
            transform=transform,
            target_transform=target_transform,
            is_valid_file=is_valid_file,
        )

        # ``imgs`` is an alias for the ``samples`` list built by the parent.
        self.imgs = self.samples
ImageFolder 的父类 DatasetFolder(重要):
class DatasetFolder(VisionDataset):
    """Folder-backed dataset that yields ``(sample, class_index)`` pairs.

    On construction the class list and the class-name -> index mapping are
    obtained from :meth:`find_classes`, and the list of
    ``(path, class_index)`` samples from :meth:`make_dataset`. A file is only
    loaded when it is requested through ``__getitem__``.

    Args:
        root: Root directory path.
        loader: Callable that loads a sample given its path.
        extensions: Optional tuple of file extensions, forwarded to
            :meth:`make_dataset`.
        transform: Optional callable applied to each loaded sample.
        target_transform: Optional callable applied to each target.
        is_valid_file: Optional predicate on a file path, forwarded to
            :meth:`make_dataset`.

    Attributes:
        loader: The ``loader`` callable passed in.
        extensions: The ``extensions`` value passed in.
        classes: Class names returned by :meth:`find_classes`.
        class_to_idx: Class-name -> index mapping from :meth:`find_classes`.
        samples: List of ``(path, class_index)`` tuples.
        targets: The class index of every entry in ``samples``.
    """

    def __init__(
        self,
        root: str,
        loader: Callable[[str], Any],
        extensions: Optional[Tuple[str, ...]] = None,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> None:
        super().__init__(root, transform=transform, target_transform=target_transform)
        # Go through the (overridable) methods rather than the module-level
        # helpers so subclasses can customize class discovery and sampling.
        classes, class_to_idx = self.find_classes(self.root)
        samples = self.make_dataset(self.root, class_to_idx, extensions, is_valid_file)

        # The original note describes the loader as "pil.open() or accimage()".
        self.loader = loader
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = class_to_idx
        self.samples = samples
        self.targets = [s[1] for s in samples]

    @staticmethod
    def make_dataset(
        directory: str,
        class_to_idx: Dict[str, int],
        extensions: Optional[Tuple[str, ...]] = None,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ) -> List[Tuple[str, int]]:
        """Return the list of ``(path, class_index)`` samples for ``directory``.

        Delegates to the module-level ``make_dataset`` function, which is
        defined outside this class.

        Raises:
            ValueError: If ``class_to_idx`` is ``None``.
        """
        if class_to_idx is None:
            # The mapping must come from the (possibly overridden)
            # ``find_classes`` method, so a missing mapping is rejected here
            # instead of being passed on to the module-level helper.
            raise ValueError("The class_to_idx parameter cannot be None.")
        # Inside a method this name resolves to the module-level function,
        # not to this static method.
        return make_dataset(
            directory,
            class_to_idx,
            extensions=extensions,
            is_valid_file=is_valid_file,
        )

    def find_classes(self, directory: str) -> Tuple[List[str], Dict[str, int]]:
        """Return ``(classes, class_to_idx)`` for ``directory``.

        Delegates to the module-level ``find_classes`` function, which is
        defined outside this class.
        """
        return find_classes(directory)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """Load and return the sample stored at ``index``.

        Args:
            index (int): Index into ``self.samples``.

        Returns:
            tuple: ``(sample, target)`` where ``sample`` is the loaded item
            after ``transform`` (if set) and ``target`` is its class index
            after ``target_transform`` (if set).
        """
        path, target = self.samples[index]
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target  # (img, label)

    def __len__(self) -> int:
        """Return the number of samples in the dataset."""
        return len(self.samples)
分析说明:
1. ImageFolder 类负责读取一张图片,对图像进行预处理(transform),然后返回预处理后的图片以及对应的标签(该图片所属类的索引);
2. ImageFolder 类继承自 DatasetFolder 类(其 __init__ 通过 super().__init__ 调用父类的构造函数);DatasetFolder 类的 __getitem__(self, index: int) 方法返回预处理后的一张图片和 target(该图片所属类的索引);
3. DataLoader 连续访问 DatasetFolder 类的 __getitem__(self, index: int) 方法,逐个获得全部预处理后的图片和 target(该图片所属类的索引),然后按照 batch_size 对结果进行拼接;