一、Dataset 和 DataLoader
二、查看 Dataset 的文档
from torch.utils.data import Dataset
# Print the built-in documentation of torch.utils.data.Dataset to the console.
help(Dataset)
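所有表示"从键(key)到数据样本的映射"的数据集(即 map-style dataset)都应该继承 `Dataset`。所有子类都应该重写 `__getitem__` 方法,用于根据给定的键取出一个数据样本;子类也可以选择性地重写 `__len__` 方法,许多 `Sampler` 的实现以及 `DataLoader` 的默认选项都依赖它返回数据集的大小。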
class Dataset(typing.Generic)
 |  An abstract class representing a :class:`Dataset`.
 |
 |  All datasets that represent a map from keys to data samples should subclass
 |  it. All subclasses should overwrite :meth:`__getitem__`, supporting fetching a
 |  data sample for a given key. Subclasses could also optionally overwrite
 |  :meth:`__len__`, which is expected to return the size of the dataset by many
 |  :class:`~torch.utils.data.Sampler` implementations and the default options
 |  of :class:`~torch.utils.data.DataLoader`.
 |
 |  .. note::
 |    :class:`~torch.utils.data.DataLoader` by default constructs a index
 |    sampler that yields integral indices. To make it work with a map-style
 |    dataset with non-integral indices/keys, a custom sampler must be provided.
三、自定义 Dataset 的代码
from torch.utils.data import Dataset
from PIL import Image
import os
class MyData(Dataset):
    """Map-style dataset over the files in one class sub-folder.

    Sample ``idx`` is the image stored at ``root_dir/label_dir/<file>``, and
    its label is the sub-folder name ``label_dir`` itself.

    Args:
        root_dir: Directory holding one sub-folder per class,
            e.g. ``"hymenoptera_dataset/train"``.
        label_dir: Name of the class sub-folder to read, e.g. ``"ants"``;
            also returned verbatim as the label of every sample.
    """

    def __init__(self, root_dir, label_dir):
        # Instance attributes (not globals): each MyData object keeps its own
        # copy so __getitem__ and __len__ can reach them.
        self.root_dir = root_dir
        self.label_dir = label_dir
        self.path = os.path.join(root_dir, label_dir)
        # File names (not full paths) of the entries in self.path.
        # os.listdir returns them in arbitrary order, so sort to make the
        # index -> image mapping deterministic across runs and platforms.
        # NOTE(review): every entry is kept, including non-image files or
        # sub-folders, which Image.open would reject -- consider filtering.
        self.img_path = sorted(os.listdir(self.path))

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file name in ``img_path``.

        ``PIL.Image.open`` is lazy: it identifies the file but keeps it open,
        and pixel data is only read on first use (or an explicit ``load()``).
        """
        img_name = self.img_path[idx]
        # Same as os.path.join(self.root_dir, self.label_dir, img_name).
        img_item_path = os.path.join(self.path, img_name)
        img = Image.open(img_item_path)
        label = self.label_dir
        return img, label

    def __len__(self):
        """Return the number of entries found in the class sub-folder."""
        return len(self.img_path)
# Both classes are sub-folders of the same training-split directory.
# NOTE(review): relative path, resolved against the current working directory.
root_dir = "hymenoptera_dataset/train"

# Class sub-folder names; MyData also returns these as the sample labels.
ants_label_dir, bees_label_dir = "ants", "bees"

# One dataset instance per class (ants first, then bees).
ants_dataset = MyData(root_dir, ants_label_dir)
bees_dataset = MyData(root_dir, bees_label_dir)
随后在 Python 控制台中运行了上述代码进行演示,并查看了各变量的值,例如 `len(ants_dataset)` 以及 `ants_dataset[0]` 返回的(图片, 标签)二元组。