hook是pytorch用于获取网络某层的输出和反向传播梯度,一共有3种:
register_hook(hook):获取变量的梯度,不能获得神经网络中间的计算量
register_forward_hook(hook):网络中前向传播的量
register_backward_hook(hook):网络中反向传播的量
我自己写的网络,想获得某个模块的反向传播权重,不能用register_hook,所以把该模块写成了nn.module
hook是pytorch用于获取网络某层的输出和反向传播梯度,一共有3种:
register_hook(hook):获取变量的梯度,不能获得神经网络中间的计算量
register_forward_hook(hook):网络中前向传播的量
register_backward_hook(hook):网络中反向传播的量
我自己写的网络,想获得某个模块的反向传播权重,不能用register_hook,所以把该模块写成了nn.module