pytorch的自定义数据集,和dataset中的图片展示

关于dataset, datalloader的使用,以及图像展示和相关参数说明

1. 图片展示

这是pytorch官网上的一个例子,获取到FashionMNIST的一个例子
[1]https://pytorch.org/tutorials/beginner/basics/data_tutorial.html

import torch
from torch.utils.data import Dataset
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt

# 获取dataset
training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

# 随机从训练集中选出图像数据进行展示
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
figure = plt.figure(figsize=(8, 8))
cols, rows = 3, 3
for i in range(1, cols * rows + 1):
    sample_idx = torch.randint(len(training_data), size=(1,)).item()
    img, label = training_data[sample_idx]
    figure.add_subplot(rows, cols, i)
    plt.title(labels_map[label])
    plt.axis("off")
    plt.imshow(img.squeeze(), cmap="gray")
plt.show()

在这里插入图片描述

2.自定义数据集

import os
import pandas as pd
from torchvision.io import read_image

class CustomImageDataset(Dataset):
    def __init__(self, annotations_file, img_dir, transform=None, target_transform=None):
        self.img_labels = pd.read_csv(annotations_file)
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.img_labels)

    def __getitem__(self, idx):
        img_path = os.path.join(self.img_dir, self.img_labels.iloc[idx, 0])
        image = read_image(img_path)
        label = self.img_labels.iloc[idx, 1]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            label = self.target_transform(label)
        return image, label

3. Preparing your data for training with DataLoaders

2中的dataset只能一个一个的获取样本,然而在训练的时候往往需要多个训练样本组成 batch,以及对顺序进行打乱避免过拟合。

The DataLoader is an iterable that abstracts this complexity for us in an easy API. We use the Dataloader, we need to set the following paraments:

  1. data the training data that will be used to train the model; and test data to evaluate the model
  2. batch size the number of records to be processed in each batch
  3. shuffle the randoms sample of the data by indices

这样就生成了一个迭代器,每次获取的样本数量有batch_size决定,顺序是否打乱由shuffle参数决定

from torch.utils.data import DataLoader

train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

可以通过迭代器展示 样本:

# Display image and label.
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")

return

Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 5

4. DataLoaders中的一些参数

除了batch_size, shuffle 外,常常使用的参数还有

  1. num_workers (int, optional) 进程数量
    – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
  2. collate_fn参数的使用
    这个参数传入的是一个函数,这个函数主要是对每个batch进行处理,最终输出一个batch的返回值,换句话说collate_fn函数的返回值,就是遍历DataLoader的时候每个“batch”的返回值了。
	def mycollate(item):
	    train_features, train_labels = item
	    return {'feature':train_features,'lable':train_labels}
	from torch.utils.data import DataLoader
	myDataloader = DataLoader(dataset, shuffle=True, batch_size=8, collate_fn=mycollate)
这样再经过遍历dataloader的时候,返回的将是mycollate函数
	for batch in myDataloader:
	    print(batch)

得到的不再是train_features, train_labels
而是 {'feature':train_features, 'lable':train_labels}

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
PyTorch 自定义数据集可以通过继承 `torch.utils.data.Dataset` 类来实现。这个类需要实现两个方法:`__len__` 和 `__getitem__`。 `__len__` 方法返回数据集的长度,即样本数量。`__getitem__` 方法返回数据集一个索引对应的样本。 下面是一个简单的例子,假设我们有一个文件夹 `data`,里面包含若干张图片和对应的标签,我们要把这个数据集PyTorch 加载起来: ```python import os from PIL import Image import torch.utils.data as data class CustomDataset(data.Dataset): def __init__(self, root_dir): self.root_dir = root_dir self.img_list = os.listdir(root_dir) def __len__(self): return len(self.img_list) def __getitem__(self, index): img_path = os.path.join(self.root_dir, self.img_list[index]) img = Image.open(img_path).convert('RGB') label = int(self.img_list[index].split('_')[0]) return img, label ``` 在上面的例子,我们定义了一个 `CustomDataset` 类,它有一个构造函数 `__init__`,接收一个参数 `root_dir` 表示数据集所在的文件夹路径。`__init__` 方法初始化了 `img_list` 属性,里面保存了所有图片文件名。 `__len__` 方法返回了 `img_list` 的长度,即数据集样本的数量。 `__getitem__` 方法接收一个索引 `index`,返回了数据集第 `index` 个样本的图片和标签。具体地,它首先获取了图片文件的路径,然后用 `PIL` 库打开图片并转换成 RGB 模式。最后,它从文件名解析出标签信息,并把图片和标签一起返回。 有了这个自定义数据集类,我们就可以用 PyTorch 的 `DataLoader` 类来加载数据集了。例如: ```python import torch.utils.data as data dataset = CustomDataset('data') dataloader = data.DataLoader(dataset, batch_size=32, shuffle=True) ``` 在上面的例子,我们创建了一个 `CustomDataset` 对象 `dataset`,然后用 `DataLoader` 类来初始化 `dataloader` 对象。`DataLoader` 的第一个参数是数据集对象,第二个参数是批量大小,第三个参数是是否打乱数据集顺序。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值