三行代码优化PyTorch调试体验:让打印张量更加简短清晰

来由:

在PyTorch调试时经常需要打印张量,为了查看张量的 大小、类型和设备。

如果直接打印张量,可能一不小心就会打印出一堆乱糟糟的数据,尤其是张量特别大时。这样既不容易看清张量的结构,还污染终端输出,影响调试时的心情☹️

还有一种情况是不时要调用 t.shape 和 t.device 来观察张量的形状和位置,或者把这两个量加入调试程序的watch中,但总的来说这样做都很麻烦。

于是这里我考虑修改张量默认的打印结果,让张量较长时不再打印其内容,而是依次打印出它的 shape,device和dtype。代码很短,只有三行,且全局开头使用一次后就会一直生效:

import torch
raw_tensor_str = torch.Tensor.__repr__
torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)

解释:

这段代码人为地修改了torch中张量的__str__和__repr__方法,这样张量在被打印时就会采取以下逻辑:1.当张量很短,例如元素在10个以下时,会打印其内容 2.如果张量很长,则会打印其 `Tensor{形状 设备 类型}`。

用了以上代码之后,就可以无脑打印张量了,不用担心突如其来的大的张量在终端里“刷屏”。

效果演示:

使用前,张量打印是这样的:

>>> print(t)
tensor([[[-0.8294, -0.0461, -0.0620,  ..., -0.6344,  0.2743, -1.4147],
         [-0.5543,  0.1475, -0.4420,  ..., -1.1269, -1.0268, -0.6499],
         [-0.5123,  1.6033, -0.9546,  ..., -0.5362, -0.7707,  0.0571],
         ...,
         [ 0.9144,  0.2731, -0.6242,  ..., -0.0733, -2.1091, -0.0747],
         [-0.2693,  0.5862, -0.5392,  ..., -0.3469,  0.6258, -0.5609],
         [-1.4400,  1.4466,  1.0845,  ..., -0.1689,  0.5124, -0.0134]]])

虽然已经打了省略号了但还是很丑,而且根本看不出形状和设备、类型。

使用后,张量打印是这样的:

>>> print(t)
Tensor{Size([1, 100, 100]) cuda:0 f32}

形状是`Size([1, 100, 100])`,设备是 cuda:0,类型是f32(即float32),是不是清晰很多了呢?

那么之后如何换回以前的打印结果呢?很简单,使用 raw_tensor_str(t) 就行!快去试试吧。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值