Pytorch代码调试工具--torchsnooper

导言:

    “RuntimeError: Expected object of scalar type Double but got scalar type Float”,这样的错误想必无数次在运行时出现,代码调试是一个十分头疼的问题,头疼归头疼,但总要解决。有些错误从提示就能看出在哪有问题,但要解决这个问题,却不一定是在提示的地方改代码。

    解决这类问题的最好方法就是在输出完整的过程,查看运行过程中每一行代码中参数的类型,shape等。从而快速、精确定位到需要改的地方。

torch snooper

Pytorch有一个十分好用的工具--torchsnooper,在可能出现bug的函数前加一个声明,即可在运行过程中输出这个函数每行代码的所有信息。

    安装:

    pip install torchsnooper

    还有另一个同样的工具-- snoop。

    pip install snoop

示例:

import torch

def myfunc(mask, x):
    y = torch.zeros(6)
    y.masked_scatter_(mask, x)
return y

mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
y = myfunc(mask, source)

运行后出现以下问题:

RuntimeError: Expected object of backend CPU but got backend CUDA for argument #2 'mask'

解决办法:使用torchsnooper,在代码前加入import torchsnooper 和@torchsnooper.snoop()

import torch
import torchsnooper
@torchsnooper.snoop()

def myfunc(mask, x):
    y = torch.zeros(6)
    y.masked_scatter_(mask, x)
    return y
mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
y = myfunc(mask, source)

运行将会出现以下内容:

Starting var:.. mask = tensor<(6,), int64, cuda:0>
Starting var:.. x = tensor<(3,), float32, cuda:0>
21:41:42.941668 call         5 def myfunc(mask, x):
21:41:42.941834 line         6     y = torch.zeros(6)
New var:....... y = tensor<(6,), float32, cpu>
21:41:42.943443 line         7     y.masked_scatter_(mask, x)
21:41:42.944404 exception    7     y.masked_scatter_(mask, x)

从提示中可以看出Y是一个在CPU上的tensor, 因此可以将y改为

y = torch.zeros(6, device='cuda')

再次运行将会出现新的问题:

RuntimeError: Expected object of scalar type Byte but got scalar type Long for argument #2 'mask'

scalar的类型应为int,但却用的是long。在上方torchsnooper输出的提示中可以看出mask的类型为int64,从而定位出问题在这里,将其改为uint8就可以了.

mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda', dtype=torch.uint8)

snooper

下面介绍另一种方法,本人一直用的是这种。

在可能出现bug的函数前使用声明@snoop,在执行到这个函数时将会显示这个函数的所有信息。

import torch
import torchsnooper
import snoop
torchsnooper.register_snoop()#在文件前面调用这个函数

@snoop #把这个声明放在想要输出的函数前
def myfunc(mask, x):
    y = torch.zeros(6)
    y.masked_scatter_(mask, x)
return y
mask = torch.tensor([0, 1, 0, 1, 1, 0], device='cuda')
source = torch.tensor([1.0, 2.0, 3.0], device='cuda')
y = myfunc(mask, source)

snoop的另一种用法:使用with torchsnooper.snoop()

with torchsnooper.snoop():
    for _ in range(100):
        optimizer.zero_grad()
        pred = model(x)
        squared_diff = (y - pred) ** 2
        loss = squared_diff.mean()
        print(loss.item())
        loss.backward()
        optimizer.step()

输出如下:

New var:....... x = tensor<(4, 2), float32, cpu>
New var:....... y = tensor<(4,), float32, cpu>
New var:....... model = Model(  (layer): Linear(in_features=2, out_features=1, bias=True))
New var:....... optimizer = SGD (Parameter Group 0    dampening: 0    lr: 0....omentum: 0    nesterov: False    weight_decay: 0)
22:27:01.024233 line        21     for _ in range(100):
New var:....... _ = 0
22:27:01.024439 line        22         optimizer.zero_grad()
22:27:01.024574 line        23         pred = model(x)
New var:....... pred = tensor<(4, 1), float32, cpu, grad>
22:27:01.026442 line        24         squared_diff = (y - pred) ** 2
New var:....... squared_diff = tensor<(4, 4), float32, cpu, grad>
22:27:01.027369 line        25         loss = squared_diff.mean()
New var:....... loss = tensor<(), float32, cpu, grad>
22:27:01.027616 line        26         print(loss.item())
22:27:01.027793 line        27         loss.backward()
22:27:01.050189 line        28         optimizer.step()

y的shape为(4,)而pred的shape为(4,1)。从而定位出问题在于pred多了一维,在代码中加一行pred = model(x).squeeze()即可解决问题。

引用:https://github.com/zasdfgbnm/TorchSnooper

 本文来源于公众号 CV技术指南 的技术总结系列。

欢迎关注公众号 CV技术指南 ,专注于计算机视觉的技术总结、最新技术跟踪、经典论文解读。

 在公众号中回复关键字 “技术总结” 可获取以下文章的汇总pdf。

其它文章

北京大学施柏鑫:从审稿人视角,谈谈怎么写一篇CVPR论文

Siamese network总结

计算机视觉专业术语总结(一)构建计算机视觉的知识体系

欠拟合与过拟合技术总结

归一化方法总结

论文创新的常见思路总结

CV方向的高效阅读英文文献方法总结

计算机视觉中的小样本学习综述   

知识蒸馏的简要概述   

优化OpenCV视频的读取速度

NMS总结   

损失函数技术总结

注意力机制技术总结   

特征金字塔技术总结   

池化技术总结

数据增强方法总结   

CNN结构演变总结(一)经典模型

CNN结构演变总结(二)轻量化模型 

CNN结构演变总结(三)设计原则

如何看待计算机视觉未来的走向   

CNN可视化技术总结(一)-特征图可视化

CNN可视化技术总结(二)-卷积核可视化

CNN可视化技术总结(三)-类可视化

CNN可视化技术总结(四)-可视化工具与项目

  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值