数据来源:请点击我
代码:
from PIL import Image
from glob import glob
from torch.utils.data import Dataset
import torchvision.transforms as transforms
class InvalidDatasetException(Exception):
    """Raised when the number of image paths differs from the number of labels.

    Args:
        len_of_paths: Number of image paths, interpolated into the message.
        len_of_labels: Number of labels, interpolated into the message.
    """

    def __init__(self, len_of_paths, len_of_labels):
        super().__init__(
            f"Number of paths ({len_of_paths}) is not compatible with number of labels ({len_of_labels})"
        )
# Image-to-tensor pipeline shared by the dataset class.
# transforms.Compose chains several transforms together; the composed
# transform does not support torchscript.
# transforms.ToTensor() converts a PIL image or a numpy array into a tensor
# and rescales the values from [0, 255] to [0, 1].
transform = transforms.Compose([transforms.ToTensor()])
# Data preprocessing
class AnimalDataset(Dataset):
    """Dataset that loads an image from disk on demand and pairs it with its label.

    Args:
        img_paths: Sequence of image file paths.
        img_labels: Sequence of labels, one per entry of ``img_paths`` and in
            the same order.
        size_of_images: Target size handed to ``PIL.Image.Image.resize``
            (a ``(width, height)`` tuple per the Pillow documentation).

    Raises:
        InvalidDatasetException: If ``img_paths`` and ``img_labels`` do not
            have the same length.
    """

    def __init__(self, img_paths, img_labels, size_of_images):
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.size_of_images = size_of_images
        if len(self.img_paths) != len(self.img_labels):
            # The exception formats two counts into its message, so pass the
            # lengths rather than the (potentially huge) lists themselves.
            raise InvalidDatasetException(len(self.img_paths), len(self.img_labels))

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load, resize and convert the image at ``index``.

        Returns:
            A ``(tensor_image, label)`` tuple, where ``tensor_image`` is the
            output of the module-level ``transform`` applied to the resized
            image and ``label`` is ``img_labels[index]``.
        """
        # Image.open is lazy; the context manager guarantees the underlying
        # file is released even if resizing fails. resize() returns a new,
        # fully loaded image, so it remains valid after the file is closed.
        with Image.open(self.img_paths[index]) as source_image:
            resized_image = source_image.resize(self.size_of_images)
        tensor_image = transform(resized_image)
        label = self.img_labels[index]
        return tensor_image, label
# Initialisation
paths = []   # image file paths
labels = []  # integer class label for each entry of `paths`
label_map = {0: "Cat",  # class index -> display name
             1: "Dog",
             2: "Wild"}

# Collect the image paths.
# Each class lives in a folder named after the lower-cased display name, under
# both the "train" and "val" splits. Classes are visited in index order and,
# within a class, "train" before "val".
dataset_root = r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq"
for class_index in sorted(label_map):
    class_folder = label_map[class_index].lower()
    for split in ("train", "val"):
        # Joined with a literal backslash so the glob pattern is exactly the
        # Windows path, e.g. ...\afhq\train\cat\*.jpg
        pattern = "\\".join((dataset_root, split, class_folder, "*.jpg"))
        for image_path in glob(pattern):
            paths.append(image_path)
            labels.append(class_index)

# Sanity output: the two counts must match, followed by the raw lists.
print(len(paths))
print(len(labels))
print(paths)
print(labels)

# Build the dataset; every image is resized to 250x250.
dataset = AnimalDataset(paths, labels, (250, 250))