2.Datasets

构建自己的数据集

from torchvision import datasets
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torchvision import transforms
import os
from PIL import Image


# 构建自己的数据集
class RMBDataset(Dataset):
    def __init__(self, data, transform=None):
        """
        rmb面额分类任务的Dataset
        :param data_dir: str, 数据集所在路径
        :param transform: torch.transform,数据预处理
        """
        self.data = data  #
        self.transform = transform

    def __getitem__(self, index):
        img, label = self.data[index]  # 已经是PIL格式
        # img = Image.open(img).convert('RGB')     # 0~255

        if self.transform is not None:  # 在这个时候进行预处理,也是单张单张的处理
            img = self.transform(img)  # 在这里做transform,转为tensor等等

        return img, label

    def __len__(self):
        return len(self.data)


# 下载并载入数据
train_data = datasets.CIFAR10("./data", train=True, download=False)
test_data = datasets.CIFAR10("./data", train=False, download=False)
# print(test_data.classes)  # 打印标签属性
# image, label = test_data[0]
# print(type(image))
# print(len(train_data))

# # 数据预处理方法
data_transform = {
    "train": transforms.Compose([transforms.ToTensor(), transforms.RandomHorizontalFlip(0.5)]),  # 随机水平翻转
    "test": transforms.Compose([transforms.ToTensor()])
}

## 构建MyDataset实例
train_data = RMBDataset(data=train_data, transform=data_transform["train"])
valid_data = RMBDataset(data=test_data, transform=data_transform["test"])

# 使用DataLoader抽取数据,
train_loader = DataLoader(dataset=train_data, batch_size=4, shuffle=True, drop_last=True)
test_loader = DataLoader(dataset=test_data, batch_size=4, shuffle=True, drop_last=True)

# 迭代图片
for epoch in range(5):
    loss_mean = 0.
    correct = 0.
    total = 0.

    for i, data in enumerate(train_loader):
        # forward
        inputs, labels = data  # 会分别将image, label打包返回

DataLoader

数据加载器。结合数据集和采样器,并提供给定数据集的可迭代对象。 torch.utils.data.DataLoader支持单进程或多进程加载的地图样式和可迭代样式数据集,自定义加载顺序和可选的自动批处理(整理)和内存固定。

参数

 Args:
 		dataset:下载或载入的数据集
 		batch_size:(int, optional),batch_size的大小
 		shuffle:(bool, optional),是否随机打乱
 		sampler:(Sampler or Iterable, optional),定义从数据集中抽取样本的策略。可以是任何实现了 __len__ 的 ``Iterable``。如果指定,则不能指定 :attr:`shuffle`。
 		batch_sampler:
 		num_workers:(int, optional) 用于数据加载的子进程数。 ``0`` 表示数据将在主进程中加载。 (默认值:``0``)
 		collate_fn:
 		pin_memory(bool, optional): 如果为“True”,数据加载器将在返回之前将张量复制到 CUDA 固定内存中。如果您的数据元素是自定义类型,或者您的 collate_fn 返回一个自定义类型的批次,请参见下面的示例。
 		 drop_last (bool, optional): 如果数据集大小不能被批次大小整除,则设置为 True 以删除最后一个不完整的批次。如果 ``False`` 并且数据集的大小不能被批大小整除,那么最后一批将更小。 (默认:“假”)
 		 timeout (numeric, optional): 如果为正,则从工人那里收集批次的超时值。应始终为非负数。 (默认值:``0``)
 		 worker_init_fn (callable, optional):
 		 generator (torch.Generator, optional):
 		 prefetch_factor (int, optional, keyword-only arg):
 		 persistent_workers:
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值