序言
本期博文介绍数据集处理和Dataset类封装,数据集采用的猫狗分类数据集,一个二分类数据集。本文围绕处理数据集的脚本构建,到Dataset类的构建进行讲解。对于模型搭建和训练则可以参考:
从零开始做图像分类任务(一)——构建图像分类模型(ResNet, VGG, ViT)
从零开始做图像分类任务(三)——训练和测试脚本(模型保存,断点恢复,Tensorboard,日志输出)
此外,数据集下载链接如下:
链接:https://pan.baidu.com/s/1_mvcB0Il63SKKF5MTBVt5w?pwd=pm07
提取码:pm07
1. 处理数据集
为了方便Dataset类的读取,我们构建一个csv文件存储数据集的图片路径和标签。
首先,在项目目录下新建脚本gen_label.py
import os
import csv
# 根据你自己的实际路径进行调整
train_dir = 'G:/Projects/Classfication/dogs-vs-cats/train'
# 获取文件夹下的所有图像文件
image_files = [f for f in os.listdir(train_dir) if f.endswith('.jpg')]
# 创建csv文件
csvfile = "label.csv"
with open(csvfile, "w", newline="") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(['Animal', 'Label'])
# 遍历图像文件并写入CSV文件
for image_file in image_files:
image_path = os.path.join(train_dir, image_file)
label = image_file.split('.')[0] # 使用文件名作为分类
id = 0 if label == 'cat' else 1
writer.writerow([image_path, id])
运行后,项目目录下会生成一个label.csv
文件:
2. Dataset构建
其次,是需要自定义供DataLoder读取的数据集类Dataset:
新建dataset.py
import torch
import pandas as pd
import os
from torch.utils.data import Dataset
from PIL import Image
class CatandDogDataset(Dataset):
def __init__(self, root_dir, transform=None):
"""
The is a dataset class, used to read image and label
:param root_dir: the csv file path
:param transform: if you want to use the data transform, transform is not None
"""
self.annotations = pd.read_csv(root_dir)
self.transform = transform
def __len__(self):
return len(self.annotations)
def __getitem__(self, index):
img_path = os.path.join(self.annotations.iloc[index, 0])
image = Image.open(img_path)
if self.transform is not None:
image = self.transform(image)
label = torch.tensor(int(self.annotations.iloc[index, 1]))
return image, label
完整代码详见:link 如果觉得还不错的话,欢迎fork和star。