大多显存的溢出,都是因为计算图在进行相乘时运算量太大,显存存不下梯度而导致溢出 而在验证集上是不需要计算梯度的,只有训练集需要更新模型参数,才需要梯度的回传,所以可以在验证集上加一句 with torch.no_grad(): 做个对比实验: 不加 with torch.no_grad(): GPU使用情况 加 with torch.no_grad(): GPU使用情况(差别挺大的)