作用:通过torch.utils.data.Dataset和torch.utils.data.DataLoader这两个类,使数据的读取变得非常简单,快捷。
torch.utils.data.Dataset
torch.utils.data.Dataset是代表自定义数据集方法的抽象类,你可以自己定义你的数据类继承这个抽象类,非常简单,只需要定义__len__和__getitem__这两个方法就可以。
当我们通过迭代的方式来取得每一个数据
class CustomDataset(Dataset):
"""默认使用 List 存储数据"""
def __init__(self, fp):
self.file = load_pkl(fp)#读取文件
def __getitem__(self, item):#item为索引
sample = self.file[item]
return sample
def __len__(self):
return len(self.file)
torch.utils.data.DataLoader
可以实现取batch,shuffle或者多线程读取数据
DataLoader(train_dataset, batch_size=cfg.batch_size, shuffle=True, collate_fn=collate_fn(cfg))
torch.utils.data.DataLoader类来定义一个新的迭代器,用来将自定义的数据读取接口的输出或者PyTorch已有的数据读取接口的输入按照batch size封装成Tensor,后续只需要再包装成Variable即可作为模型的输入。
collate_fn
可理解为函数句柄、指针…或者其他可调用类(实现__call__函数)。 函数对传入的batch数据进一步处理。具体如下:
def collate_fn(cfg):
def collate_fn_intra(batch):#对每一个batch进行处理
batch.sort(key=lambda data: data['seq_len'], reverse=True)
max_len = batch[0]['seq_len']
# 对数据集进行padding,
def _padding(x, max_len):
return x + [0] * (max_len - len(x))
x, y = dict(), []
word, word_len = [], []
head_pos, tail_pos = [], []
pcnn_mask = []
for data in batch:
word.append(_padding(data['token2idx'], max_len))
word_len.append(data['seq_len'])
y.append(int(data['rel2idx']))
if cfg.model_name != 'lm':
head_pos.append(_padding(data['head_pos'], max_len))
tail_pos.append(_padding(data['tail_pos'], max_len))
if cfg.model_name == 'cnn':
if cfg.use_pcnn:
pcnn_mask.append(_padding(data['entities_pos'], max_len))
x['word'] = torch.tensor(word)
x['lens'] = torch.tensor(word_len)
y = torch.tensor(y)
if cfg.model_name != 'lm':
#转换为torch。tensor
x['head_pos'] = torch.tensor(head_pos)
x['tail_pos'] = torch.tensor(tail_pos)
if cfg.model_name == 'cnn' and cfg.use_pcnn:
x['pcnn_mask'] = torch.tensor(pcnn_mask)
if cfg.model_name == 'gcn':
# 没找到合适的做 parsing tree 的工具,暂时随机初始化
B, L = len(batch), max_len
adj = torch.empty(B, L, L).random_(2)
x['adj'] = adj
return x, y
return collate_fn_intra
抽象方法
__len__(self)
定义当被len()函数调用时的行为(返回容器中元素的个数)
__getitem__(self)
定义获取容器中指定元素的行为,相当于self[key],即允许类对象可以有索引操作。
__iter__(self)
定义当迭代容器中的元素的行为