class AddTrainableMask(ABC):
_tensor_name: str
def __init__(self):
pass
def __call__(self, module, inputs):
setattr(module, self._tensor_name, self.apply_mask(module))
def apply_mask(self, module):
mask_train = getattr(module, self._tensor_name + "_mask_train")
mask_fixed = getattr(module, self._tensor_name + "_mask_fixed")
orig_weight = getattr(module, self._tensor_name + "_orig_weight")
pruned_weight = mask_train * mask_fixed * orig_weight
return pruned_weight
@classmethod
def apply(cls, module, name, mask_train, mask_fixed, *args, **kwargs):
method = cls(*args, **kwargs)
method._tensor_name = name
orig = getattr(module, name)
module.register_parameter(name + "_mask_train", mask_train.to(dtype=orig.dtype))
module.register_parameter(name + "_mask_fixed", mask_fixed.to(dtype=orig.dtype))
module.register_parameter(name + "_orig_weight", orig)#这个权重参数的id指向原来的weight
del module._parameters[name]
setattr(module, name, method.apply_mask(module))#每次forwar之前都会调用这个钩子,所以每次weight的权重都被直接改了
module.register_forward_pre_hook(method)
return method
用法
register_forward_pre_hook(hook)
返回:
一个句柄,可用于通过调用 handle.remove() 删除添加的钩子
返回类型:
torch.utils.hooks.RemovableHandle
在模块上注册一个前向预挂钩。
每次调用 forward() 之前都会调用该钩子。它应该具有以下签名:
hook(module, input) -> None or modified input
输入仅包含给模块的位置参数。关键字参数不会传递给钩子,只会传递给 forward 。钩子可以修改输入。用户可以在钩子中返回一个元组或单个修改值。如果返回单个值,我们会将值包装到一个元组中(除非该值已经是一个元组)。