Pytorch中DataLoader的collate_fn()参数学习笔记

1 Dataset和DataLoader创建和加载数据

使用pytorch训练网络之前的数据准备部分都要有两个流程:Dataset和DataLoader。前者用来定义自己的数据集类型(内部实现最主要的是__getitem__()方法);而后者则是负责真正在运行的时侯给网络递送数据。

1.1 Dataset类

继承Dataset类,自定义数据处理类。必须重载实现len()、getitem()这两个方法。
其中__len__返回数据集样本的数量,而__getitem__应该编写支持数据集索引的函数,返回数据和对应label,例如:通过dataset[i]可以得到数据集中的第i+1个数据。

1.2 DataLoader类

DataLoader完整的参数表如下:

class torch.utils.data.DataLoader(
 dataset,
 batch_size=1,
 shuffle=False,
 sampler=None,
 batch_sampler=None,
 num_workers=0,
 collate_fn=<function default_collate>,
 pin_memory=False,
 drop_last=False,
 timeout=0,
 worker_init_fn=None)

几个关键的参数意思:
dataset:PyTorch已有的数据读取接口或自定义数据接口的输出
batch_size:根据具体情况设置
shuffle:设置为True的时候,每个迭代都会打乱数据集,一般在训练数据中会采用
num_workers:这个参数必须大于等于0,0表示数据导入在主进程中进行,大于0表示通过多个进程来导入数据,可以加快数据导入速度。
collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
drop_last:告诉如何处理数据集长度除于batch_size余下的数据。True抛弃,否则保留

通常说来,我们在编写完Dataset之后,其内部的__getitem__会弹出一个[data, label]的一条数据,DataLoader中的collate_fn函数将这些一条一条的数据组织成一个batch。

注意:
通常的,默认的collate_fn函数是要求一个batch中的图片都具备相同size,当一个batch中的图片大小都不一样时(或者想要定值batch的输出形式),使用自定义的collate_fn函数。

1.3 自定义batch

通过collate_fn函数可以对这些样本做进一步的处理(任何你想要的处理),原则上返回值应当是一个有结构的batch。而DataLoader每次迭代的返回值就是collate_fn的返回值。

可以使用collate_fn的同时,结合使用默认的default_collate。

from torch.utils.data.dataloader import default_collate  # 导入这个函数
def my_collate_fn(batch):
    """
    params:
        batch :是一个列表,列表的长度是 batch_size,list中的每个元素都是__getitem__得到的结果。
               列表的每一个元素是 (x,y) 这样的元组tuple,元祖的两个元素分别是x,y
               大致的格式如下 [(x1,y1),(x2,y2),(x3,y3)...(xn,yn)]
    returns:
        整理之后的新的batch
    """
    # 这一部分是对 batch 进行重新 “校对、整理”的代码
    return default_collate(batch) #返回校对之后的batch,一般就直接推荐使用default_collate进行包装,因为它里面有很多功能,比如将numpy转化成tensor等操作,这是必须的。

然后调用时使用:

trainset = DataLoader(dataset=train_dataset,
                      batch_size=24,
                      shuffle=True,
                      collate_fn=my_collate_fn,

2 实例

图神经网络时,将多张图合并为一张大图。

"""
Combine multiple graphs into one large graph.
"""

#- collate_fn:如何取样本的,我们可以定义自己的函数来准确地实现想要的功能
# collate_fn这个函数的输入就是一个list,list的长度是一个batch size,list中的每个元素都是__getitem__得到的结果。
def collate_fn(batch):
    nodes_list = [b[0] for b in batch] #b[0]=p.array(nodes)
    nodes = np.concatenate(nodes_list, axis=0) #所有节点拼到一起,不扩展维度,拼成一个array
    #map 对于node_list每一组,计算shape(即节点个数)。也就是一个可迭代对象。返回array(每个图的节点个数)
    nodes_lens = np.fromiter(map(lambda l: l.shape[0], nodes_list), dtype=np.int64)
    nodes_inds = np.cumsum(nodes_lens) #计算一个数组各行的累加值
    nodes_num = nodes_inds[-1] #最后一个值,即总节点个数
    nodes_inds = np.insert(nodes_inds, 0, 0) #在第一个位置插入0这个值
    nodes_inds = np.delete(nodes_inds, -1) #按行展开后 删除最后一个元素

    edges_list = [b[1] for b in batch] #np.array(edges)
    edges_list = [e + i for e, i in zip(edges_list, nodes_inds)] #e是边的连接(每个图都从0开始) i是总节点个数
    edges = np.concatenate(edges_list, axis=0)
    m = edges_to_matrix(nodes_num, edges) #将每个batch拼接成一个邻接矩阵

    labels = [b[2] for b in batch] #np.array([float(label)])
    labels = np.concatenate(labels, axis=0)
    # batch中第i个图的节点个数为k batch_mask数组分别为[0,..,0] [1,..,1] 其中...为节点个数
    batch_mask = [np.array([i] * k, dtype=np.int32) for i, k in zip(range(len(batch)), nodes_lens)]
    batch_mask = np.concatenate(batch_mask, axis=0)

    #返回节点类型 邻接矩阵 预测值 batch_mask
    return torch.from_numpy(nodes), torch.from_numpy(m).float(), torch.from_numpy(labels), torch.from_numpy(batch_mask)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

来包番茄沙司

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值