Dataset和DataLoader用法
在d2l中有简洁的加载固定数据的方式,如下
d2l.load_data_fashion_mnist()
# 源码
Signature: d2l.load_data_fashion_mnist(batch_size, resize=None)
Source:
def load_data_fashion_mnist(batch_size, resize=None):
    """Download the Fashion-MNIST dataset and then load it into memory.
    Defined in :numref:`sec_fashion_mnist`"""
    # Transform pipeline: optionally resize FIRST, then convert PIL image to tensor.
    trans = [transforms.ToTensor()]
    if resize:
        trans.insert(0, transforms.Resize(resize))
    trans = transforms.Compose(trans)
    # download=True fetches the data into ../data on first use.
    mnist_train = torchvision.datasets.FashionMNIST(
        root="../data", train=True, transform=trans, download=True)
    mnist_test = torchvision.datasets.FashionMNIST(
        root="../data", train=False, transform=trans, download=True)
    # Returns a (train_loader, test_loader) pair: training data is shuffled
    # each epoch, test data keeps its original order.
    return (data.DataLoader(mnist_train, batch_size, shuffle=True,
                            num_workers=get_dataloader_workers()),
            data.DataLoader(mnist_test, batch_size, shuffle=False,
                            num_workers=get_dataloader_workers()))
File: ~/anaconda3/envs/d2l/lib/python3.9/site-packages/d2l/torch.py
Type: function
如果我们要自定义需要加载的数据集
数据集:一个图片文件夹,用csv文件来表示训练数据和标签
# 定义Dataset
import os

import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
    """Image-classification dataset backed by a CSV index file.

    The CSV's first column holds an image path relative to ``root_dir``
    and its ``label`` column holds the class name.  Labels are
    integer-encoded by sorted unique value (the same mapping sklearn's
    ``LabelEncoder`` produces) and returned one-hot as float tensors.
    """

    def __init__(self, csv_file, root_dir, transform=None):
        self.data = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform
        # Integer-encode the labels: sorted unique labels -> 0..C-1.
        # pandas categorical codes give the same mapping as sklearn's
        # LabelEncoder for string labels, without the extra dependency.
        # Cast to int64 because F.one_hot requires a LongTensor index.
        categories = self.data['label'].astype('category')
        self.labels = categories.cat.codes.to_numpy().astype('int64')
        # Hoisted out of __getitem__ (was recomputed per item via nunique()).
        self.num_classes = len(categories.cat.categories)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # First CSV column is the image path relative to root_dir.
        img_name = os.path.join(self.root_dir, self.data.iloc[idx, 0])
        # Load the image and apply augmentation/transforms if any.
        image = Image.open(img_name)
        if self.transform is not None:
            image = self.transform(image)
        # One-hot encode as float (required by soft-label style losses).
        label = F.one_hot(torch.tensor(self.labels[idx]),
                          num_classes=self.num_classes).float()
        return image, label
# ----- Hyper-parameters -----
batch_size = 256
# BUG FIX: the original `lr = num_epoch = 0.9, 10` is a chained assignment
# that binds the TUPLE (0.9, 10) to both names; tuple unpacking was intended.
lr, num_epoch = 0.9, 10

# ----- Data paths (Kaggle "classify-leaves" competition) -----
sample = '/kaggle/input/classify-leaves/sample_submission.csv'
ts_path = "/kaggle/input/classify-leaves/test.csv"
tr_path = "/kaggle/input/classify-leaves/train.csv"
image_path = '/kaggle/input/classify-leaves'

# NOTE(review): the dataset is built from `sample` (sample_submission.csv);
# presumably `tr_path` (train.csv) was intended — confirm before training.
# NOTE(review): `transform_train` is not defined in this file — it must be
# provided elsewhere (e.g. a torchvision transforms.Compose pipeline).
dataset = CustomDataset(csv_file=sample, root_dir=image_path,
                        transform=transform_train)

# 80/20 random train/validation split.
train_size = int(0.8 * len(dataset))
valid_size = len(dataset) - train_size
tr_dataset, te_dataset = torch.utils.data.random_split(
    dataset, [train_size, valid_size])

# Shuffle the training loader each epoch; keep validation order fixed.
tr_dataloader = DataLoader(tr_dataset, batch_size, shuffle=True)
ts_dataloader = DataLoader(te_dataset, batch_size, shuffle=False)
总结
需要将__init__、__len__、__getitem__按照数据集和模型的要求,对应地编写好代码。