参考:
https://blog.csdn.net/zw__chen/article/details/82806900
https://blog.csdn.net/Threelights/article/details/88680540
一 要点总结
- 1 torch.utils.data.Dataset是一个包装类,用来将数据包装为Dataset类,然后传入DataLoader中。
- 2 将数据包装为Dataset类有两种方式:一种是用 torch.utils.data.TensorDataset 来将数据包装成Dataset类,一种是自己写一个继承 torch.utils.data.Dataset 的类,实现类中的 __len__ 方法和 __getitem__ 方法。
- 3 torchvision.datasets中的所有数据集都是 torch.utils.data.Dataset 的子类,因而都实现了 __getitem__ 和 __len__ 方法,可以直接传递给 torch.utils.data.DataLoader
二 示例
1 用 torch.utils.data.TensorDataset 来将数据包装成Dataset类
import torch  # fix: torch.tensor is used below but torch was never imported
import h5py
import numpy as np
from torch.utils.data import DataLoader, TensorDataset, Dataset

# Example 1: wrap pre-loaded arrays into a Dataset with TensorDataset.
# The HDF5 file exposes the keys 'list_classes', 'train_set_x', 'train_set_y'.
signs_h5 = h5py.File('datasets/train_signs.h5', "r")
print(signs_h5.keys())  # inspect the available datasets

# Copy images and labels out of the file into numpy arrays.
x_train = np.array(signs_h5["train_set_x"][:])
# NHWC -> NCHW: PyTorch convolution layers expect channels-first layout.
x_train = np.transpose(x_train, (0, 3, 1, 2))
y_train = np.array(signs_h5["train_set_y"][:])
# Reshape the label vector into a column of shape (N, 1).
y_train = y_train.reshape((1, y_train.shape[0])).T
# fix: close the file once the arrays are copied (original left it open and
# then shadowed the handle with the TensorDataset, leaking it).
signs_h5.close()

# Scale pixel values to [0, 1]; labels stay integer (long) for loss functions.
X_train_tensor = torch.tensor(x_train, dtype=torch.float) / 255
Y_train_tensor = torch.tensor(y_train, dtype=torch.long)

# TensorDataset pairs the two tensors sample-wise; DataLoader batches them.
train_dataset = TensorDataset(X_train_tensor, Y_train_tensor)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        print(epoch, i, "inputs", inputs.size(), "labels", labels.size())
![在这里插入图片描述](https://img-blog.csdnimg.cn/20210620125535166.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0xJV0VJOTQwNjM4MDkz,size_16,color_FFFFFF,t_70)
2 继承 torch.utils.data.Dataset 的类,实现类中的 __len__ 方法和 __getitem__ 方法
class TrainDataset(Dataset):
    """Map-style Dataset over the SIGNS data stored in an HDF5 file.

    The whole file is loaded into memory in ``__init__``; ``__getitem__``
    then serves ``(image, label)`` tensor pairs by index, which is all a
    DataLoader needs.
    """

    def __init__(self):
        # Read images and labels from the HDF5 file (keys: train_set_x,
        # train_set_y), then close the handle once the arrays are copied.
        h5 = h5py.File('datasets/train_signs.h5', "r")
        x_train = np.array(h5["train_set_x"][:])
        y_train = np.array(h5["train_set_y"][:])
        h5.close()
        # NHWC -> NCHW for PyTorch; labels reshaped to a (N, 1) column.
        x_train = np.transpose(x_train, (0, 3, 1, 2))
        y_train = y_train.reshape((1, y_train.shape[0])).T
        # Normalize pixels to [0, 1]; labels as long for loss functions.
        self.x_data = torch.tensor(x_train, dtype=torch.float) / 255
        self.y_data = torch.tensor(y_train, dtype=torch.long)

    def __getitem__(self, index):
        """Return the (image, label) pair at *index*."""
        return self.x_data[index], self.y_data[index]

    def __len__(self):
        # Bug fix: the original returned x_train.shape[0], but x_train is a
        # local variable of __init__ and is NOT in scope here (it only worked
        # by accident via a notebook-level global). Use the stored tensor.
        return self.x_data.shape[0]
# Build the custom dataset; the two bare expressions below are notebook-style
# inspections (they print nothing when run as a script).
trainDataset = TrainDataset()
type(trainDataset)
len(trainDataset)

# A custom Dataset plugs into DataLoader exactly like a built-in one.
train_loader2 = DataLoader(dataset=trainDataset,
                           batch_size=64,
                           shuffle=True)

for epoch in range(2):
    for i, data in enumerate(train_loader2):
        inputs, labels = data
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())
![在这里插入图片描述](https://img-blog.csdnimg.cn/20210620130638738.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0xJV0VJOTQwNjM4MDkz,size_16,color_FFFFFF,t_70)
3 torchvision.datasets.mnist使用示例
from torchvision.datasets import mnist
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor

# Example 3: torchvision datasets are already torch.utils.data.Dataset
# subclasses, so they go straight into a DataLoader.
# Without a transform, indexing yields (PIL.Image, int) pairs.
train_dataset = mnist.MNIST(root='./train', train=True)
train_dataset           # notebook-style inspections (no-ops in a script)
type(train_dataset)
len(train_dataset)
train_dataset[0]
train_dataset[0][1]

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())

# With transform=ToTensor(), each image becomes a FloatTensor scaled to [0, 1].
train_dataset = mnist.MNIST(root='./train', train=True, transform=ToTensor())
train_dataset
type(train_dataset)
len(train_dataset)
train_dataset[0]

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
for epoch in range(2):
    for i, data in enumerate(train_loader):
        inputs, labels = data
        print("epoch:", epoch, "的第" , i, "个inputs", inputs.size(), "labels", labels.size())
![在这里插入图片描述](https://img-blog.csdnimg.cn/20210620133409763.png?x-oss-process=image/watermark,type_ZmFuZ3poZW5naGVpdGk,shadow_10,text_aHR0cHM6Ly9ibG9nLmNzZG4ubmV0L0xJV0VJOTQwNjM4MDkz,size_16,color_FFFFFF,t_70)