来由:
在PyTorch调试时经常需要打印张量,为了查看张量的 大小、类型和设备。
如果直接打印张量,可能一不小心就会打印出一堆乱糟糟的数据,尤其是张量特别大时。这样既不容易看清张量的结构,还污染终端输出,影响调试时的心情☹️。
还有一种情况是不时要调用 t.shape 和 t.device 来观察张量的形状和位置,或者把这两个量加入调试程序的watch中,但总的来说这样做都很麻烦。
于是这里我考虑修改张量默认的打印结果,让张量较长时不再打印其内容,而是依次打印出它的 shape,device和dtype。代码很短,只有三行,且全局开头使用一次后就会一直生效:
import torch
raw_tensor_str = torch.Tensor.__repr__
torch.Tensor.__str__ = torch.Tensor.__repr__ = lambda self: f'Tensor{{Size({[*self.shape]}) {self.device} {str(self.dtype)[6]}{str(self.dtype)[-2:]}}}' if self.numel() > 10 else raw_tensor_str(self)
解释:
这段代码人为地修改了torch中张量的__str__和__repr__方法,这样张量在被打印时就会采取以下逻辑:1.当张量很短,例如元素在10个以下时,会打印其内容 2.如果张量很长,则会打印其 `Tensor{形状 设备 类型}`。
用了以上代码之后,就可以无脑打印张量了,不用担心突如其来的大的张量在终端里“刷屏”。
效果演示:
使用前,张量打印是这样的:
>>> print(t)
tensor([[[-0.8294, -0.0461, -0.0620, ..., -0.6344, 0.2743, -1.4147],
[-0.5543, 0.1475, -0.4420, ..., -1.1269, -1.0268, -0.6499],
[-0.5123, 1.6033, -0.9546, ..., -0.5362, -0.7707, 0.0571],
...,
[ 0.9144, 0.2731, -0.6242, ..., -0.0733, -2.1091, -0.0747],
[-0.2693, 0.5862, -0.5392, ..., -0.3469, 0.6258, -0.5609],
[-1.4400, 1.4466, 1.0845, ..., -0.1689, 0.5124, -0.0134]]])
虽然已经打了省略号了但还是很丑,而且根本看不出形状和设备、类型。
使用后,张量打印是这样的:
>>> print(t)
Tensor{Size([1, 100, 100]) cuda:0 f32}
形状是`Size([1, 100, 100])`,设备是 cuda:0,类型是f32(即float32),是不是清晰很多了呢?
那么之后如何换回以前的打印结果呢?很简单,使用 raw_tensor_str(t) 就行!快去试试吧。