pytorch collaten_fn函数取出数值操作

collate_fn函数可以在取出数值的时候对每次进行前面补充0的操作,很方便每一个批次对于数据进行操作
首先定义对应的MyDataset类

import torch
from torch.utils.data import Dataset, DataLoader
# 1. 用Dataset封装数据集,仅做示范,实际可直接用TensorDataset封装
class MyDataset(Dataset):
    def __init__(self, x, y):
        #assert x.size(0)==y.size(0)
        self.x, self.y = x, y
    #定义初始化变量
    def __getitem__(self, idx):
        #print('__getitem__')
        return (self.x[idx], self.y[idx])
    #定义每次取出的对应数值
    def __len__(self):
        return len(self.x)
    #定义tensor的总长度
# 2. 用DataLoader定义数据批量迭代器

接下来定义collate_fn函数为每次取出的数值的操作,这里之所以collate_fn函数最终决定而不是由__getitem__函数决定,是因为__getitem__函数在collate_fn函数之前运行,
这里面的__getitem__函数可以视为每一次取出数据的时候对数据进行的操作,而collate_fn函数是针对于每一个批次取出数据进行的操作
放入对应的参数,这里使用collate_fn函数让每一个批次以这个批次数值的最大长度作为分界,这个批次如果达不到最大长度的需要在前面的位置补充零。

test1 = [[3],[1,2],[5,6,7],[9]]
test2 = [[6],[4,5],[7,8,9],[10]]
def collate_fn(data):
    #这里的collate_fn是对相应的数据进行处理
    #找出最大长度的数组,如果其他数组达不到相应的长度,
    #就在句子的前面位置补充0
    print('function collate_fn')
    def _pad_sequences(seqs):
        seqs_0 = []
        seqs_1 = []
        for  seq  in  seqs:
            seqs_0.append(seq[0])
            seqs_1.append(seq[1])
        lens0 = [len(seq) for seq in seqs_0]
        lens1 = [len(seq) for seq in seqs_1]
        max_len0 = max(lens0)
        max_len1 = max(lens1)
        max_len = max(max_len0,max_len1)
        padded_seqs0 = torch.zeros(len(seqs_0),max_len).long()
        padded_seqs1 = torch.zeros(len(seqs_1),max_len).long()
        for i  in range(len(seqs_0)):
            seq = seqs_0[i]
            start = max_len-lens0[i]
            padded_seqs0[i,start:] = torch.Tensor(seq)
        for i  in range(len(seqs_1)):
            seq = seqs_1[i]
            start = max_len-lens1[i]
            padded_seqs1[i,start:] = torch.Tensor(seq)
        return  padded_seqs0,padded_seqs1
    data1,data2 = _pad_sequences(data)
    #print('data = ')
    #print(data)
    return data1,data2
dataset = MyDataset(test1,test2)
MyDataLoader = DataLoader(dataset=dataset,shuffle=True,batch_size=4,collate_fn=collate_fn)
for data_iter1,data_iter2 in MyDataLoader:
    print('data_iter1 = ')
    print(data_iter1)
    print('data_iter2 = ')
    print(data_iter2)
    #print('data_iter2 = ')
    #print(data_iter2)

最终MyDataLoader循环之中每次取出来的值由collate_fn的返回值return data1,data2决定

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值