网络在前期可以正常训练,但训练几轮后就发生显存爆炸的问题,调整输入大小或者每次循环都清除显存 也无法解决问题,后来经过查询,是在对loss求和时,直接使用
tl += loss
可以看到,loss是张量,经过运算后,tl也是张量,在神经网络中,pytorch会默认将张量操作放到计算图中,随着训练次数的增加,计算图会越来越大,直至显存爆炸。
解决办法:
tl += loss.item()
计算图原理:
计算图中每个节点代表一个输入,每条边代表一个运算操作。例如y=(a+b)(b+c),则a,b,c都是节点,之后a+b在连接到一个节点,b+c连接到一个节点,最后两节点连接输出。
pytorch是动态建立计算图,边建立边计算。
tensorflow是静态建立计算图