对于 nn.Module
可以使用 nn.Module.requires_grad_(False)
原理
torch 前向传播会记录每个 tensor 的计算图,反向传播会根据 tensor 的 requires_grad: bool 来决定是否计算梯度
通过设置 .requires_grad_(False),使网络递归的将其参数(tensor)设置为 requires_grad = False, 从而不计算梯度
注意
建议复写 .train 方法
def train(mode):
if not self.trainable:
return self.eval()