自定义DataLoader
这是为了补上次讲自定义Dataset挖的坑
连接如果想看前文 链接: 自定义Dataset
前情提要
由于DataLoader的进行是需要在Dataset的基础上,这是Dataset的基本结构
class MyDataset(Dataset):
def __init__(self, datas, label_list):
self.datas = datas
self.labels = []
self.label_list = label_list
keys = list(set([y for y in self.label_list]))
keys.sort()
dictkeys = {key: ii for ii, key in enumerate(keys)}
print(dictkeys)
for i in range(len(self.label_list)):
self.labels.append(dictkeys[label_list[i]])
def __len__(self):
return len(self.labels)
def __getitem__(self, idx):
return torch.FloatTensor(self.datas[idx]), torch.tensor(self.labels[idx])
数据准备,
data
的shape为(26, 10),第一行全为1,第二行全为2,一直到最后一行全为26。label
为26个字母
import numpy as np
# 创建一个 shape 为 (24, 10) 的数组
array_shape = (24, 10)
# 使用 np.arange() 函数生成从 1 到 24 的数组
data_array = np.arange(1, array_shape[0] + 1)
# 使用 np.tile() 函数将每个元素复制 10 次,构成行
result_array = np.tile(data_array, (array_shape[1], 1)).T
print(result_array)
# 使用列表推导式生成从 A 到 Z 的字母列表
alphabet_list = [chr(i) for i in range(ord('A'), ord('Z') + 1)]
print(alphabet_list)
到重点了
其实就是继承DataLoader类,然后将自己写的方法替代其中的self.collate_fn方法
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
def _collate_fn(batch):
datas = []
labels = []
for i in range(len(batch)):
datas.append(batch[i][0])
labels.append(batch[i][1])
datas = torch.stack(datas, dim=0)
labels = torch.stack(labels)
return datas, labels
class MyDataLoader(DataLoader):
def __init__(self, *args, **kwargs):
super(MyDataLoader, self).__init__(*args, **kwargs)
self.collate_fn = _collate_fn
至于self.collate_fn
起什么作用,我们都知道DataLoader中有一个参数batch_size
,假如我们将batch_size设置为4,那么你可以理解为,调用Dataset中的__getitem__(self, idx)
4次
然后将__getitem__(self, idx)得到的返回值设为元组,再将四个元组添加到一个list中,也就变成我们所说的一个batch, 这个batch便会作为参数传给self.collate_fn函数
其中还有一个参数shuffle
,如果设置为True,那么就会随机的设置__getitem__(self, idx)的idx, 设置为False,按顺序取idx
最后就可以使用了
train_dataset = MyDataset(result_array, alphabet_list)
train_dataloader = MyDataLoader(train_dataset, batch_size=4, shuffle=False)
for idx, (data, label) in enumerate(train_dataloader):
print(label)