【DEBUG】报错RuntimeError: Trying to resize storage that is not resizable 解决思路分享

【DEBUG】报错RuntimeError: Trying to resize storage that is not resizable 解决思路分享

问题来源

问题跟这个博主的类似 RuntimeError: Trying to resize storage that is not resizable

我认为原因是你dataloader当中存在数据shape大小不一致。

比如假设你 dataset getitem 写好了返回一个 tensor。假设shape是 (100, 100), 那么报错的原因就是存在一种情况你的 getitem 返回了一个 shape 不是 (100,100)而是另一个不同形状的的 tensor。

这个时候 pytorch 没法拼接成 [B, C, X, …] Batch 大小的 tensor 用于后续训练。(假设你是batch是1,应该就会在训练的时候报错)

解决方式

(省流)最高效的办法是重新读入一遍 dataloader, 人为的 iter 一遍,检查形状不一致的情况;

需要说明的是,这个问题不是你 dataloader return 数据有问题,而是你数据的形状有问题,所以应该是要假定 return 出来的数据是不是你预想中的大小。

(这是我踩完坑觉得最高效的方法,请一定首先检查是不是形状的问题,再考虑是不是 num_worker 跟 gpu 数量对不上的问题)




下面是我自己具体分析这个问题的过程,权供网友们参考。

在分析之前,包含了以下的知识点:

  • ddp 装饰显示 error
  • 自定义 collate_function 打印错误
  • 使用 pdb 分析

我首先想到的就是在主进程中debug,因为可能dataloader加载的方式很复杂,写的很臃肿(比如我),或者定位不到问题(写成了 ddp),下面单独说下ddp情况要先加什么代码:

DDP:

# Pytorch ddp 由于多卡的问题,不能把 error tracelace 打印出来,这样就不知道具体的报错
# 第一步就是把你的整个训练代码写在 main() 函数下面
# 第二步是用 record 这个函数装饰一下 main()


from torch.distributed.elastic.multiprocessing.errors import record

@record
def main():
	# 把你的代码放进来
	pass

这样你运行 torch distributed lanch 就会具体报错,报的还是 RuntimeError: Trying to resize storage that is not resizable.

DDP & DP

Pytorch Dataloader 有一个很大的问题就是在 vscode 里面打不了断点,不方便debug。

下面首先就是怎么定位到有问题的地方,既然是 collate function 报错,那么我们就自定义一个进去找问题

from torch.utils.data._utils.collate import default_collate

def custom_collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError as e:
        print(f"Collate error: {e}")
        # Handle the error or skip the batch
        for iter, items in enumerate(batch):
        	for item in items:
        		print(item.shape)
        return None

这个时候,运行完程序,应该就会有一个报错,然后有具体的shape,应该就能明白问题了。假如还不够清晰,可以加一个 assert 或者 if else 来进一步确认,我肉眼一个一个比对看到了问题。

下一步是找到具体是哪一个 data load 进来有问题,确认我是全部的 data 设置有问题还是只是单个的,很直觉的想法就是找到对应的 dataset 在 dataloader 过程中究竟是哪一个 idx 报错。 方法也很简单,用 pdb set 一个 breakpoint 然后手动 debug 一下就可以。

from torch.utils.data._utils.collate import default_collate

def custom_collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError as e:
    	breakpoint()# 手动设置一个 breakpoint
        print(f"Collate error: {e}")
        # Handle the error or skip the batch
        for iter, items in enumerate(batch):
        	for item in items:
        		print(item.shape)
        return None

# Dataloader 这里也要详细设置一下,把 num_worker 置 0,让程序在主进程运行
your_loader = DataLoader(...,
                        num_workers=0,
                        ...)

假设文件名是 main.py

python main.py

运行的时候,报错就会进入 pdb,要简单了解的话可以参考这个 b 站视频 码农高天-10分钟入门pdb,但其实思路还是比较简单的,总体上记住 s 代表 step,where 代表运行到哪里,u 代表上溯,d 代表下溯,想看什么变量直接打变量名就行。

经过这样一番debug后,我也是确认了 idx 的具体值,然后进入 dataloader 单独验证发现了问题,并修复。

  • 19
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值