深度学习——pytorch自制、定义数据集

数据来源:请点击我

代码:

from PIL import Image
from glob import glob
from torch.utils.data import Dataset
import torchvision.transforms as transforms


class InvalidDatasetException(Exception):   # 如果照片路径总数不等于标签总数将执行这段代码

    def __init__(self, len_of_paths, len_of_labels):  
        super().__init__(
            f"Number of paths ({len_of_paths}) is not compatible with number of labels ({len_of_labels})"
        )


transform = transforms.Compose([transforms.ToTensor()])  # 转换,transforms.Compose将几个变换组合在一起,这个转换不支持torchscript。
# transforms.ToTensor()转换一个PIL库的图片或者numpy的数组为tensor张量类型;转换从[0,255]->[0,1]


# 数据预处理
class AnimalDataset(Dataset):  # 处理图片的类
    def __init__(self, img_paths, img_labels, size_of_images):  # 统一定义变量,img_paths:所有图片路径;img_labels:所有图片标签;size_of_images:图片大小
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.size_of_images = size_of_images
        if len(self.img_paths) != len(self.img_labels):  # 图片数量!=标签数量
            raise InvalidDatasetException(self.img_paths, self.img_labels)  # raise手动引发的异常

    def __len__(self):  # 返回图片总数
        return len(self.img_paths)

    def __getitem__(self, index):  # 处理每一张图片
        PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)  # 打开图片
        TENSOR_IMAGE = transform(PIL_IMAGE)  # 转换图片
        label = self.img_labels[index]  # 获取图片标签
        return TENSOR_IMAGE, label  # 返回转换后的图片和标签


# 初始化
paths = []  # 图片路径
labels = []  # 图片标签
label_map = {0: "Cat",  # 标签
             1: "Dog",
             2: "Wild"}

# 提取图片路径
for cat_path in glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\train\cat\*.jpg") + glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\val\cat\*.jpg"):
    paths.append(cat_path)
    labels.append(0)

for dog_path in glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\train\dog\*.jpg") + glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\val\dog\*.jpg"):
    paths.append(dog_path)
    labels.append(1)

for wild_path in glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\train\wild\*.jpg") + glob(r"C:\Users\AIAXIT\Desktop\DeepLearningProject\Project\Animal Faces\afhq\val\wild\*.jpg"):
    paths.append(wild_path)
    labels.append(2)

print(len(paths))
print(len(labels))
print(paths)
print(labels)

# 调用
dataset = AnimalDataset(paths, labels, (250, 250))

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值