Pytorch 2.0.1内存泄漏问题
更新:经过多次测试,已经确信这个BUG只在2.0.x版本中存在,避免使用2.0.x版本即可。此贴完结。
当推理的数据含有大量不同的shape时,会导致内存泄漏。一段发生泄露的代码:
from torchvision.models import resnet
import torch
from memory_profiler import profile
net = resnet.resnet50(pretrained=True)
net = net.cuda()
net.train()
@profile(precision=4,stream=open('resnet.log','w'))
def infer(width, height):
data = torch.randn(2, 3, width, height)
x = data.clone().cuda()
out = net(x)
torch.cuda.empty_cache()
for width in range(100, 2000, 10):
print(width)
for height in range(100, 2000 ,10):
infer(width, height)
用memory_profiler打印一下,内存占用直线上升:
我的环境是pytorch 2.0.1。
个人测试了多次,推测原因是,pytorch似乎对每个shape的tensor都在内存里占用了一定缓存,而且没办法清理。当shape越来越多时,就OOM了。
对于开发者来说,对shape的缓存或许是pytorch性能优化的一部分,但是开发者显然没有考虑到推理用到大量不同shape导致OOM的情况。再加上目前似乎没有官方API来清理这些缓存(如果有的话小伙伴踢我一下),所以我认为应该将之视为一个内存泄漏的bug。
欢迎有兴趣的小伙伴自己跑一下
更希望有大佬帮忙解决问题(灬ꈍ ꈍ灬)