之前写代码的时候遇到的一个问题,一直没有解决,后来稀里糊涂的解决了,我也不知道原因,这里贴出来,希望大家遇到这个问题的时候能有些启发,图来自网上搜索,由于问题是很久以前的了,当时没有保存截图,抱歉了。
这个问题的出现其实可以将 loss.backward() 写为 loss.backward(retain_graph=True) 来解决的,但是这样会增加内存的消耗,后来我仔细去输出了我的变量类型,确实都是叶子节点,按道理来说是不会被释放的,这里我的loss.backward()也只使用了一次,所以可以确定是不存在被释放的,后面也输出了其状态,发现其在loss.backward()前均是叶子节点,也存在梯度。如下图所示,四个参数打印出来均是叶子节点,这里由于没有时间运行,就没有展示。
后来在别人发来的链接中突然受到启发,将最后的return a11改为了return a11.data ,这样算法就没有报错了。不知道这个原理是什么,可能是return 返回 a11后,进行了叶子节点的计算吧,由于代码能跑通了,所以也就没有仔细想了,如果有知道的希望告知一下,谢谢。下图是调用的函数以及函数返回值。
所有参数均为叶子节点,返回a11则报错,返回a11.data则问题消失。