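The idea is that the parameters should satisfy a certain condition: the weights of the convolutional layers should be orthogonal.
If they are not orthogonal, measure their distance from orthogonality and optimize that distance as a loss.
This program provides a PyTorch implementation of orthogonal regularization; it directly returns the loss for a model.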
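
Written out, the quantity the function below computes can be sketched as follows. Here $K$ denotes one $H \times W$ slice of a 4-D convolution weight (a symbol introduced only for this note), $\mathbf{1}$ is the all-ones matrix, $I$ the identity, $\odot$ the element-wise product, and the sum runs over every slice of every trainable 4-D weight in the model:

```latex
\mathcal{L}_{\mathrm{orth}}
  = \beta \sum_{K} \left\lVert K K^{\top} \odot (\mathbf{1} - I) \right\rVert_F^{2}
```

Only the off-diagonal entries of $K K^{\top}$ are penalized, so the rows of each slice are pushed toward mutual orthogonality while their norms are left unconstrained.
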
import torch


def orthogonal_regularization(model, device, beta=1e-4):
    r"""Orthogonal regularization loss for the convolutional weights of a model.

    author: Xu Mingle
    time: 2019-02-19 15:12:43

    input:
        model: the model to regularize, e.g. a Generator or Discriminator
        device: cpu or gpu; should be the device the model's parameters live on
        beta: hyperparameter that scales the penalty
    output: loss, a scalar tensor
    """
    # For one kernel slice K of shape H x W, the penalty is
    #   beta * (||K^T.K * (1-I)||_F)^2   (K^T.K has shape W x W), or
    #   beta * (||K.K^T * (1-I)||_F)^2   (K.K^T has shape H x H)
    # where 1 is the all-ones matrix, I the identity and * the element-wise product.
    # The smaller product needs less memory: K.K^T when H < W, K^T.K when H > W.
    # This implementation always uses K.K^T.
    loss_orth = torch.tensor(0., dtype=torch.float32, device=device)
    for name, param in model.named_parameters():
        # Select parameters that are weights (not biases), are actually being
        # trained, and are 4-D, i.e. convolution kernels. 2-D weights such as
        # embedding layers are skipped.
        if 'weight' in name and param.requires_grad and len(param.shape) == 4:
            N, C, H, W = param.shape
            # Treat each H x W kernel slice as its own matrix.
            weight = param.reshape(N * C, H, W)
            weight_squared = torch.bmm(weight, weight.permute(0, 2, 1))  # (N * C) x H x H
            # Mask that is 0 on the diagonal and 1 elsewhere: H x H, broadcast over the batch.
            off_diag = 1 - torch.eye(H, dtype=torch.float32, device=device)
            loss_orth += ((weight_squared * off_diag) ** 2).sum()
    return loss_orth * beta
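
As a rough sanity check of the per-layer computation, the sketch below re-implements it as a standalone helper and evaluates it on two hand-picked weight settings. The helper name `kernel_orth_penalty` and the layer sizes are assumptions made for this illustration and are not part of the original program.

```python
import torch
import torch.nn as nn


def kernel_orth_penalty(weight: torch.Tensor) -> torch.Tensor:
    """Sum of squared off-diagonal entries of K.K^T over all kernel slices K.

    Hypothetical helper for illustration; `weight` is a 4-D conv weight.
    """
    n, c, h, w = weight.shape
    slices = weight.reshape(n * c, h, w)
    gram = torch.bmm(slices, slices.transpose(1, 2))  # (n * c, h, h)
    off_diag = 1 - torch.eye(h, dtype=weight.dtype, device=weight.device)
    return ((gram * off_diag) ** 2).sum()


# Arbitrary small layer: weight shape is (3, 2, 3, 3), i.e. 6 slices of 3 x 3.
conv = nn.Conv2d(in_channels=2, out_channels=3, kernel_size=3)

# Identity slices have orthonormal rows, so the penalty is exactly zero.
with torch.no_grad():
    conv.weight.copy_(torch.eye(3).expand(3, 2, 3, 3))
zero_penalty = kernel_orth_penalty(conv.weight)

# All-ones slices: every entry of K.K^T equals 3, there are 6 off-diagonal
# entries per slice and 6 slices, so the penalty is 6 * 6 * 3**2 = 324.
with torch.no_grad():
    conv.weight.fill_(1.0)
ones_penalty = kernel_orth_penalty(conv.weight)

# The penalty is differentiable, so it can be scaled and added to a training loss.
beta = 1e-4
(beta * ones_penalty).backward()
grad_is_finite = torch.isfinite(conv.weight.grad).all().item()
```

In a training loop, the value returned by `orthogonal_regularization(model, device)` would typically be added to the task loss before calling `backward()`, so that a single backward pass covers both terms.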