# Forward pass with recovery from CUDA out-of-memory (OOM) errors.
#
# Why the recovery happens *after* the ``except`` clause: while the handler is
# running, the caught exception's traceback still references the stack frames
# of the failed forward pass, and therefore the intermediate tensors those
# frames hold. Calling torch.cuda.empty_cache() at that point cannot return
# that memory. Python clears the ``as exception`` name when the ``except``
# clause ends, so the handler only records the failure and the cleanup runs
# afterwards.
#
# ``outputs`` is reset first so that a stale result (e.g. from a previous
# iteration, if this runs inside a loop) does not keep its tensors alive
# through the OOM cleanup. After an OOM, ``outputs`` is None and ``oom`` is
# True; presumably the caller should skip this batch (TODO confirm against
# the surrounding training loop).
outputs = None
oom = False
try:
    outputs = model(train_datas).cuda()
except RuntimeError as exception:
    # Anything other than an OOM is a genuine error: re-raise the active
    # exception as-is.
    if "out of memory" not in str(exception):
        raise
    oom = True

if oom:
    print("WARNING: out of memory")
    # Guarded in case this torch build does not provide empty_cache().
    if hasattr(torch.cuda, 'empty_cache'):
        torch.cuda.empty_cache()
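单独在 except 块中调用 torch.cuda.empty_cache() 并不能释放显存：异常处理期间，异常对象的回溯（traceback）仍引用着失败的前向传播所在的栈帧及其中的张量。应先在 except 块中仅记录 OOM 标志，退出 except 块、释放对相关张量（如 outputs）的引用之后，再调用 torch.cuda.empty_cache()。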