def debug_memory():
    """Print peak RSS and a tally of all live tensors, for leak hunting.

    Walks every object tracked by the garbage collector, counts tensors
    grouped by ``(device, dtype, shape)``, and prints one line per group.
    Calling this periodically inside a training loop lets you spot groups
    whose counts keep growing — a sign of unintended references keeping
    tensors (and their GPU/CPU memory) alive.

    Returns:
        None. All output goes to stdout.

    Note:
        The ``resource`` module is Unix-only; on Windows the import raises.
        ``ru_maxrss`` units differ by platform (KiB on Linux, bytes on macOS).
    """
    # Local imports: this is a debug-only helper, so the dependencies are
    # paid for only when it is actually called.
    import collections
    import gc
    import resource
    import torch

    print('maxrss = {}'.format(
        resource.getrusage(resource.RUSAGE_SELF).ru_maxrss))

    tensors = collections.Counter(
        (str(o.device), o.dtype, tuple(o.shape))
        for o in gc.get_objects()
        if torch.is_tensor(o))

    # Sort on the string form of each item: torch.dtype objects are not
    # orderable, so comparing the raw key tuples directly can raise
    # TypeError when two different dtypes live on the same device.
    for line in sorted(tensors.items(), key=str):
        print('{}\t{}'.format(*line))
使用上面的函数,在 training loop 里调用,可以追踪各类张量占用显存的数量变化,从而发现数量一直在增长、非正常占用显存的问题变量。
下面是一个例子(这里面没有用gpu,都是cpu)
>>> z = [torch.randn(i).long() for i in range(10)]
>>> debug_memory()
('cpu', torch.float32, (3, 3)) 2
('cpu', torch.int64, (0,)) 1
('cpu', torch.int64, (1,)) 1
('cpu', torch.int64, (2,)) 1
('cpu', torch.int64, (3,)) 1
('cpu', torch.int64, (4,)) 1
('cpu', torch.int64, (5,)) 1
('cpu', torch.int64, (6,)) 1
('cpu', torch.int64, (7,)) 1
('cpu', torch.int64, (8,)) 1
('cpu', torch.int64, (9,)) 1
来源:A clever trick to debug tensor memory - Misc. - Pyro Discussion Forum