PyTorch: freezing part of a model

Given $y = f(x;\theta)$, suppose we do not want the gradient of the loss with respect to $\theta$, but we do want to keep the gradient with respect to $x$. Simply wrapping the computation in a no_grad() scope will not do (it kills the gradient for $x$ as well); instead, $f$ has to be frozen.
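
The core pattern is to turn off requires_grad on the parameters of $f$ before the graph is built. A minimal sketch of the idea demonstrated at length in the full script below (the linear layer standing in for $f$ and the random $x$ are made up for illustration):

import torch
import torch.nn as nn

f = nn.Linear(2, 3)                          # stands in for f(.; theta)
for p in f.parameters():                     # freeze theta
    p.requires_grad = False
x = torch.randn(4, 2, requires_grad=True)    # still want d(loss)/dx
loss = f(x).sum()
loss.backward()
print(x.grad is not None)     # True: gradient w.r.t. x survives
print(f.weight.grad is None)  # True: no gradient w.r.t. theta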

Code

import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter


os.system("cls")
N = 4
DIM_X = 2
DIM_Y = 3


# Toy autoencoder: DIM_X-dim input -> DIM_Y-dim latent -> DIM_X-dim reconstruction.
class Autoencoder(nn.Module):

    def __init__(self):
        super(Autoencoder, self).__init__()

        self._encoder = nn.Sequential(
            nn.Linear(DIM_X, DIM_Y),
            nn.BatchNorm1d(DIM_Y),
            nn.ReLU()
        )

        self._decoder = nn.Linear(DIM_Y, DIM_X)

    def encoder(self, x):
        return self._encoder(x)

    def decoder(self, latent):
        return self._decoder(latent)

    def forward(self, x):
        latent = self.encoder(x)
        x_hat = self.decoder(latent)
        return x_hat, latent


# Wraps a trainable tensor as an nn.Module, so it can be frozen/unfrozen
# just like any other sub-model.
class TrainVar(nn.Module):

    def __init__(self, *size, init_val=None, process_fn=None):
        super(TrainVar, self).__init__()
        self.size = size
        self.process_fn = process_fn
        if init_val is None:
            self.weight = Parameter(torch.Tensor(*size))
            self.reset_parameters()
        else:
            self.weight = Parameter(init_val * torch.ones(*size, dtype=torch.float))

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self):
        if self.process_fn is None:
            return self.weight
        else:
            return self.process_fn(self.weight)


def show_param(model):
    for name, param in model.named_parameters():
        print(name, param.size())
    # for param in model.parameters():
        # print(param.size())


def show_grad(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name, "has grad:", param.grad.size())
        # else:
        #     print(name, "has NO grad")


def freeze_module(*modules):
    """冻结模型"""
    for m in modules:
        for param in m.parameters():
            param.requires_grad = False


def activate_module(*modules):
    """解冻"""
    for m in modules:
        for param in m.parameters():
            param.requires_grad = True


def zero_grad(*modules):
    """手动清梯度"""
    for m in modules:
        for param in m.parameters():
            param.grad = None


ae = Autoencoder()
X = TrainVar(N, DIM_X)

print("--- AE ---")
show_param(ae)
# show_grad(ae)
print("--- X ---")
show_param(X)
# show_grad(X)

idx = np.array([0, 2], dtype="int")


print("\t普通 forward")
x = X()[idx]
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
loss.backward()
print("-- AE")
show_grad(ae)  # has grads
print("-- X")
show_grad(X)  # has grads

zero_grad(ae, X)
# show_grad(ae)
# show_grad(X)


print("\t套 no_grad 域")
with torch.no_grad():
    x = X()[idx]
    x_hat, z = ae(x)
    loss = F.mse_loss(x_hat, x)
    # loss.backward()  # would raise an error: loss was built under no_grad() and has no grad_fn
print("-- AE")
show_grad(ae)  # none
print("-- X")
show_grad(X)  # none


zero_grad(ae, X)


print("\t只冻结 X")
freeze_module(X)
x = X()[idx]  # must be re-run: requires_grad takes effect when the graph is (re)built
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
loss.backward()
print("-- AE")
show_grad(ae)  # has grads
print("-- X")
show_grad(X)  # none


zero_grad(ae, X)


print("\t只冻结 AE")
activate_module(X)
freeze_module(ae)
x = X()[idx]  # again, must be re-run after toggling requires_grad
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
zero_grad(ae, X)  # clearing .grad here is fine: it does not touch the graph built above
loss.backward()
print("-- AE")
show_grad(ae)  # none
print("-- X")
show_grad(X)  # has grads
Output
--- AE ---
_encoder.0.weight torch.Size([3, 2])
_encoder.0.bias torch.Size([3])
_encoder.1.weight torch.Size([3])
_encoder.1.bias torch.Size([3])
_decoder.weight torch.Size([2, 3])
_decoder.bias torch.Size([2])
--- X ---
weight torch.Size([4, 2])
        plain forward
-- AE
_encoder.0.weight has grad: torch.Size([3, 2])
_encoder.0.bias has grad: torch.Size([3])
_encoder.1.weight has grad: torch.Size([3])
_encoder.1.bias has grad: torch.Size([3])
_decoder.weight has grad: torch.Size([2, 3])
_decoder.bias has grad: torch.Size([2])
-- X
weight has grad: torch.Size([4, 2])
        wrapped in a no_grad() block
-- AE
-- X
        freeze only X
-- AE
_encoder.0.weight has grad: torch.Size([3, 2])
_encoder.0.bias has grad: torch.Size([3])
_encoder.1.weight has grad: torch.Size([3])
_encoder.1.bias has grad: torch.Size([3])
_decoder.weight has grad: torch.Size([2, 3])
_decoder.bias has grad: torch.Size([2])
-- X
        freeze only AE
-- AE
-- X
weight has grad: torch.Size([4, 2])
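
In practice the same requires_grad flags decide what an optimizer will update. A sketch of a typical training step under the last setting above (only X trainable; the Adam optimizer and the learning rate are arbitrary choices for illustration):

activate_module(X)
freeze_module(ae)
trainable = [p for p in list(ae.parameters()) + list(X.parameters()) if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=1e-3)  # receives only X.weight

x = X()[idx]
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
opt.zero_grad()
loss.backward()
opt.step()  # updates X.weight only; ae stays fixed
# note: requires_grad does not stop BatchNorm running-stat updates; use ae.eval() for that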
