hook()函数的作用很强大,pytorch中通常会自动舍弃图计算的中间结果,所以想要获取网络中间层的输出结果或者某些变量的梯度,就可以使用Hook函数来实现,hook函数包括tensor的hook和nn.Module的hook,用法相似。hook函数主要有x.register_hook(hook), layer.register_forward_hook()和layer.register_backward_hook(),x是模型的参数,第一个主要用于获得x的梯度信息,后面两个主要用于模型前向和后向运行中获取输入输出的结果,申明Hook函数必须在模型训练的前面,然后更重要的是:使用完hook后,要及时删除,不然很容易造成hook函数结果累积,使得显存爆掉。可使用hook.remove()删除。
使用register_hook()函数导致显存溢出的问题
最新推荐文章于 2023-04-20 21:09:41 发布