一般会在训练的时候报错,因为代码中通常有torch.backends.cudnn.benchmark = True做加速,
可能是torch版本的问题,我的版本是2.0.1,换版本后问题还在,其实根据提示把所有与他有关的代码注释掉就可以了。
# torch.backends.cuda.matmul.allow_tf32 = False
# torch.backends.cudnn.benchmark = True
# torch.backends.cudnn.deterministic = False
# torch.backends.cudnn.allow_tf32 = True
这一系列代码全部都注释掉,问题解决。
实践过程中如果爆显存了,有哪些解决办法:
1.batch逐渐减小试试,64->32->16->8
2. torch.cuda.empty_cache() 删除碎片化的内存可以试试
3.换电脑或服务器(bushi