RuntimeError: each element in list of batch should be of equal size

RuntimeError: each element in list of batch should be of equal size

示例代码:

import os
import re
from torch.utils.data import Dataset, DataLoader

data_base_path = r'./aclImdb/'


#  1.定义token的方法
def tokenize(test):
    """Split raw review text into a list of tokens.

    Strips HTML tags first, then replaces punctuation/control characters
    with spaces, and finally splits on whitespace.

    Args:
        test: raw review text (parameter name kept for backward
            compatibility; presumably a typo for ``text``).

    Returns:
        list[str]: the individual tokens.
    """
    filters = ['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/',
               ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
               '{', '|', '}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub(r"<.*?>", " ", test, flags=re.S)
    # Bug fix: the second substitution must run on the tag-stripped `text`,
    # not on the original `test`, otherwise the HTML-tag removal is discarded.
    # re.escape builds a correct alternation — the hand-written '\\' entry
    # (a single backslash) used to escape the '|' join separator instead of
    # matching a backslash.
    text = re.sub("|".join(re.escape(ch) for ch in filters), " ", text, flags=re.S)
    return [i.strip() for i in text.split()]


#  2.准备dataset
class ImdbDataset(Dataset):
    """IMDB review dataset yielding (label, token_list) pairs.

    Labels are parsed from the filename (``<id>_<rating>.txt``) and shifted
    from the 1-10 rating scale down to the 0-9 range.
    """

    def __init__(self, mode):
        """Collect all review file paths for the requested split.

        Args:
            mode: "train" selects train/neg + train/pos; any other value
                selects test/neg + test/pos.
        """
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        """Return (label, tokens) for the review at index ``item``."""
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Filename looks like "123_8.txt"; the trailing number is the 1-10
        # rating, shifted here to [0-9].
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1
        # Bug fix: close the file deterministically and read as UTF-8 —
        # the original left the handle open and relied on the platform
        # default encoding, which breaks on non-ASCII reviews on Windows.
        with open(cur_path, encoding="utf-8") as f:
            text = tokenize(f.read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


#  3. Instantiate the dataset and prepare the DataLoader.
dataset = ImdbDataset(mode="train")
# NOTE(review): batch_size=2 with the default collate_fn is exactly what
# triggers "each element in list of batch should be of equal size" — the
# per-sample token lists have different lengths, so default_collate cannot
# stack them. This snippet intentionally reproduces the error.
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True)

#  4. Inspect the first batch.
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break

运行结果:

报错原因:

dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True),发现是这行代码导致的错误,如果把batch_size=2改为batch_size=1时就不再报错了,运行结果如下:

但是如果想让batch_size=2时,这个错误该如何解决呢?

解决方法如下:

出现问题的原因在于Dataloader中的参数collate_fn

collate_fn的默认值为torch自定义的default_collate,collate_fn的作用就是对每个batch进行处理,而默认的default_collate处理出错。

解决问题的思路:

  • 手段1:考虑先把数据转化为数字序列,观察其结果是否符合要求,之前使用DataLoader并未出现类似错误
  • 手段2:考虑自定义一个collate_fn,观察结果

这里使用方式2,自定义一个collate_fn,然后观察结果:

def collate_fn(batch):
    """Regroup a batch of (label, tokens) pairs into (labels, texts).

    ``batch`` is the list of ``__getitem__`` results for one mini-batch.
    The variable-length token lists are kept as a plain tuple instead of
    letting the default collate try (and fail) to stack them.
    """
    labels, texts = zip(*batch)
    return torch.tensor(labels, dtype=torch.int32), texts

全部代码:

import os
import re
import torch
from torch.utils.data import Dataset, DataLoader

data_base_path = r'./aclImdb/'


#  1.定义token的方法
def tokenize(test):
    """Split raw review text into a list of tokens.

    Strips HTML tags first, then replaces punctuation/control characters
    with spaces, and finally splits on whitespace.

    Args:
        test: raw review text (parameter name kept for backward
            compatibility; presumably a typo for ``text``).

    Returns:
        list[str]: the individual tokens.
    """
    filters = ['!', '"', '#', '$', '%', '&', '(', ')', '*', '+', ',', '-', '.', '/',
               ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`',
               '{', '|', '}', '~', '\t', '\n', '\x97', '\x96', '”', '“']
    text = re.sub(r"<.*?>", " ", test, flags=re.S)
    # Bug fix: the second substitution must run on the tag-stripped `text`,
    # not on the original `test`, otherwise the HTML-tag removal is discarded.
    # re.escape builds a correct alternation — the hand-written '\\' entry
    # (a single backslash) used to escape the '|' join separator instead of
    # matching a backslash.
    text = re.sub("|".join(re.escape(ch) for ch in filters), " ", text, flags=re.S)
    return [i.strip() for i in text.split()]


#  2.准备dataset
class ImdbDataset(Dataset):
    """IMDB review dataset yielding (label, token_list) pairs.

    Labels are parsed from the filename (``<id>_<rating>.txt``) and shifted
    from the 1-10 rating scale down to the 0-9 range.
    """

    def __init__(self, mode):
        """Collect all review file paths for the requested split.

        Args:
            mode: "train" selects train/neg + train/pos; any other value
                selects test/neg + test/pos.
        """
        super().__init__()
        if mode == "train":
            text_path = [os.path.join(data_base_path, i) for i in ["train/neg", "train/pos"]]
        else:
            text_path = [os.path.join(data_base_path, i) for i in ["test/neg", "test/pos"]]
        self.total_file_path_list = []
        for i in text_path:
            self.total_file_path_list.extend([os.path.join(i, j) for j in os.listdir(i)])

    def __getitem__(self, item):
        """Return (label, tokens) for the review at index ``item``."""
        cur_path = self.total_file_path_list[item]
        cur_filename = os.path.basename(cur_path)
        # Filename looks like "123_8.txt"; the trailing number is the 1-10
        # rating, shifted here to [0-9].
        label = int(cur_filename.split("_")[-1].split(".")[0]) - 1
        # Bug fix: close the file deterministically and read as UTF-8 —
        # the original left the handle open and relied on the platform
        # default encoding, which breaks on non-ASCII reviews on Windows.
        with open(cur_path, encoding="utf-8") as f:
            text = tokenize(f.read().strip())
        return label, text

    def __len__(self):
        return len(self.total_file_path_list)


def collate_fn(batch):
    """Regroup a batch of (label, tokens) pairs into (labels, texts).

    ``batch`` is the list of ``__getitem__`` results for one mini-batch.
    The variable-length token lists are kept as a plain tuple instead of
    letting the default collate try (and fail) to stack them.
    """
    labels, texts = zip(*batch)
    return torch.tensor(labels, dtype=torch.int32), texts


#  3. Instantiate the dataset and prepare the DataLoader, this time passing
#  the custom collate_fn so variable-length token lists are batched as-is.
dataset = ImdbDataset(mode="train")
dataloader = DataLoader(dataset=dataset, batch_size=2, shuffle=True, collate_fn=collate_fn)

#  4. Inspect the first batch.
for idx, (label, text) in enumerate(dataloader):
    print("idx:", idx)
    print("label:", label)
    print("text:", text)
    break

运行效果:

  • 19
    点赞
  • 41
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值