Python torch.nn.Module.register_forward_pre_hook用法及代码示例

class AddTrainableMask(ABC):

    _tensor_name: str
    
    def __init__(self):
        pass
    
    def __call__(self, module, inputs):

        setattr(module, self._tensor_name, self.apply_mask(module))

    def apply_mask(self, module):

        mask_train = getattr(module, self._tensor_name + "_mask_train")
        mask_fixed = getattr(module, self._tensor_name + "_mask_fixed")
        orig_weight = getattr(module, self._tensor_name + "_orig_weight")
        pruned_weight = mask_train * mask_fixed * orig_weight

        return pruned_weight

    @classmethod
    def apply(cls, module, name, mask_train, mask_fixed, *args, **kwargs):

        method = cls(*args, **kwargs)  
        method._tensor_name = name
        orig = getattr(module, name)

        module.register_parameter(name + "_mask_train", mask_train.to(dtype=orig.dtype))
        module.register_parameter(name + "_mask_fixed", mask_fixed.to(dtype=orig.dtype))
        module.register_parameter(name + "_orig_weight", orig)#这个权重参数的id指向原来的weight
        del module._parameters[name]

        setattr(module, name, method.apply_mask(module))#每次forwar之前都会调用这个钩子,所以每次weight的权重都被直接改了
        module.register_forward_pre_hook(method)

        return method


用法

register_forward_pre_hook(hook)

返回:
一个句柄,可用于通过调用 handle.remove() 删除添加的钩子

返回类型:

torch.utils.hooks.RemovableHandle

在模块上注册一个前向预挂钩。

每次调用 forward() 之前都会调用该钩子。它应该具有以下签名:

hook(module, input) -> None or modified input

输入仅包含给模块的位置参数。关键字参数不会传递给钩子,只会传递给 forward 。钩子可以修改输入。用户可以在钩子中返回一个元组或单个修改值。如果返回单个值,我们会将值包装到一个元组中(除非该值已经是一个元组)。

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值