OOM报错如图所示:
训练跑得好好的,怎么测试的时候反倒报错了呢?
原来是没有加with torch.no_grad(),在测试的时候有gradient容易报内存溢出,加上:
with torch.no_grad():
testY = model(testX)
再跑代码就OK了。
OOM报错如图所示:
训练跑得好好的,怎么测试的时候反倒报错了呢?
原来是没有加with torch.no_grad(),在测试的时候有gradient容易报内存溢出,加上:
with torch.no_grad():
testY = model(testX)
再跑代码就OK了。