Given $y = f(x; \theta)$, suppose we want the loss to carry no gradient with respect to $\theta$ while keeping the gradient with respect to $x$. Simply wrapping the computation in a `no_grad()` block does not work (it blocks both gradients); instead, we need to freeze $f$.
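A minimal sketch of the idea (here `f` is just a stand-in `nn.Linear`; the full demo below does the same with an autoencoder): set `requires_grad = False` on $\theta$ *before* building the graph, and gradients then flow only to $x$. For the opposite case, keeping the gradient with respect to $\theta$ but not $x$, detaching the input with `x.detach()` is the usual tool.

```python
import torch
import torch.nn as nn

f = nn.Linear(2, 3)                        # stands in for f(.; theta)
x = torch.randn(4, 2, requires_grad=True)

for p in f.parameters():                   # freeze theta BEFORE the forward pass
    p.requires_grad = False

loss = f(x).pow(2).mean()
loss.backward()

print(x.grad is not None)    # True: gradient w.r.t. x is kept
print(f.weight.grad is None) # True: no gradient w.r.t. theta
```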
Code
import os
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
os.system("cls")  # clear the console (Windows-only)
N = 4
DIM_X = 2
DIM_Y = 3
class Autoencoder(nn.Module):
    def __init__(self):
        super(Autoencoder, self).__init__()
        self._encoder = nn.Sequential(
            nn.Linear(DIM_X, DIM_Y),
            nn.BatchNorm1d(DIM_Y),
            nn.ReLU()
        )
        self._decoder = nn.Linear(DIM_Y, DIM_X)

    def encoder(self, x):
        return self._encoder(x)

    def decoder(self, latent):
        return self._decoder(latent)

    def forward(self, x):
        latent = self.encoder(x)
        x_hat = self.decoder(latent)
        return x_hat, latent
class TrainVar(nn.Module):
    """A trainable tensor wrapped as an nn.Module, so it can be frozen like a model."""

    def __init__(self, *size, init_val=None, process_fn=None):
        super(TrainVar, self).__init__()
        self.size = size
        self.process_fn = process_fn
        if init_val is None:
            self.weight = Parameter(torch.Tensor(*size))
            self.reset_parameters()
        else:
            self.weight = Parameter(init_val * torch.ones(*size, dtype=torch.float))

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))

    def forward(self):
        if self.process_fn is None:
            return self.weight
        else:
            return self.process_fn(self.weight)
def show_param(model):
    for name, param in model.named_parameters():
        print(name, param.size())
    # for param in model.parameters():
    #     print(param.size())

def show_grad(model):
    for name, param in model.named_parameters():
        if param.grad is not None:
            print(name, "has grad:", param.grad.size())
        # else:
        #     print(name, "has NO grad")

def freeze_module(*modules):
    """Freeze modules: their parameters stop requiring gradients."""
    for m in modules:
        for param in m.parameters():
            param.requires_grad = False

def activate_module(*modules):
    """Unfreeze modules."""
    for m in modules:
        for param in m.parameters():
            param.requires_grad = True

def zero_grad(*modules):
    """Manually clear gradients."""
    for m in modules:
        for param in m.parameters():
            param.grad = None
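# Note: param.grad = None mirrors optimizer.zero_grad(set_to_none=True); it
# drops the gradient tensors instead of filling them with zeros, so
# show_grad() really reports "no grad" after clearing.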
ae = Autoencoder()
X = TrainVar(N, DIM_X)
print("--- AE ---")
show_param(ae)
# show_grad(ae)
print("--- X ---")
show_param(X)
# show_grad(X)
idx = np.array([0, 2], dtype="int")
print("\t普通 forward")
x = X()[idx]
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
loss.backward()
print("-- AE")
show_grad(ae) # 有
print("-- X")
show_grad(X) # 有
zero_grad(ae, X)
# show_grad(ae)
# show_grad(X)
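# Gradients accumulate across backward() calls, so clear them between
# experiments; otherwise stale gradients would distort the checks below.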
print("\t套 no_grad 域")
with torch.no_grad():
x = X()[idx]
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
# loss.backward() # raise error
print("-- AE")
show_grad(ae) # 无
print("-- X")
show_grad(X) # 无
zero_grad(ae, X)
print("\t只冻结 X")
freeze_module(X)
x = X()[idx] # must be recalled
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
loss.backward()
print("-- AE")
show_grad(ae) # 有
print("-- X")
show_grad(X) # 无
zero_grad(ae, X)
print("\t只冻结 AE")
activate_module(X)
freeze_module(ae)
x = X()[idx] # must be recalled
x_hat, z = ae(x)
loss = F.mse_loss(x_hat, x)
zero_grad(ae, X)
loss.backward()
print("-- AE")
show_grad(ae) # 无
print("-- X")
show_grad(X) # 有
- Output
--- AE ---
_encoder.0.weight torch.Size([3, 2])
_encoder.0.bias torch.Size([3])
_encoder.1.weight torch.Size([3])
_encoder.1.bias torch.Size([3])
_decoder.weight torch.Size([2, 3])
_decoder.bias torch.Size([2])
--- X ---
weight torch.Size([4, 2])
    plain forward
-- AE
_encoder.0.weight has grad: torch.Size([3, 2])
_encoder.0.bias has grad: torch.Size([3])
_encoder.1.weight has grad: torch.Size([3])
_encoder.1.bias has grad: torch.Size([3])
_decoder.weight has grad: torch.Size([2, 3])
_decoder.bias has grad: torch.Size([2])
-- X
weight has grad: torch.Size([4, 2])
    inside no_grad
-- AE
-- X
    freeze X only
-- AE
_encoder.0.weight has grad: torch.Size([3, 2])
_encoder.0.bias has grad: torch.Size([3])
_encoder.1.weight has grad: torch.Size([3])
_encoder.1.bias has grad: torch.Size([3])
_decoder.weight has grad: torch.Size([2, 3])
_decoder.bias has grad: torch.Size([2])
-- X
    freeze AE only
-- AE
-- X
weight has grad: torch.Size([4, 2])
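Two caveats worth noting; both are general PyTorch behavior rather than anything specific to this demo. First, freezing parameters does not freeze `BatchNorm1d`'s running statistics: `running_mean` and `running_var` are buffers, updated whenever the module runs in training mode, with or without autograd. Second, an optimizer built over all parameters still holds the frozen ones, so the usual pattern is to filter them out. A sketch (reusing `ae`, `X`, and `freeze_module` from above):

```python
# 1) BatchNorm: also switch to eval mode, or running stats keep updating.
freeze_module(ae)
ae.eval()

# 2) Optimizer: only hand over parameters that still require grad.
trainable = [p for p in list(ae.parameters()) + list(X.parameters())
             if p.requires_grad]
opt = torch.optim.Adam(trainable, lr=1e-3)
```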