The previous post walked through training GAN Compression's teacher generator. This one continues with the distillation step, the supernet training step, and the sub-model search step.
Distillation step:
bash scripts/cycle_gan/horse2zebra/distill.sh
distill.sh runs distill.py, which again hands off to trainer.py:
class Trainer:
    def __init__(self, task):
        if task == 'train':
            from options.train_options import TrainOptions as Options
            from models import create_model as create_model
        elif task == 'distill':
            from options.distill_options import DistillOptions as Options
            from distillers import create_distiller as create_model
        elif task == 'supernet':
            from options.supernet_options import SupernetOptions as Options
            from supernets import create_supernet as create_model
        else:
            raise NotImplementedError('Unknown task [%s]!!!' % task)
        opt = Options().parse()
The only difference from the previous post is that task == 'distill' this time, so inside trainer.py:
- Options is actually DistillOptions.
- create_model is actually create_distiller.
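The dispatch logic above can be boiled down to a table lookup. Here is a minimal, self-contained sketch, with the imported classes and factory functions replaced by strings purely for illustration:

```python
# Minimal sketch of Trainer.__init__'s task dispatch; the real code
# imports actual classes and factory functions instead of strings.
def resolve_task(task):
    registry = {
        'train': ('TrainOptions', 'create_model'),
        'distill': ('DistillOptions', 'create_distiller'),
        'supernet': ('SupernetOptions', 'create_supernet'),
    }
    if task not in registry:
        raise NotImplementedError('Unknown task [%s]!!!' % task)
    return registry[task]

print(resolve_task('distill'))  # ('DistillOptions', 'create_distiller')
```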
Following the same approach as the previous post, jump to the definition of the DistillOptions class, which declares the command-line arguments used during distillation: distiller (default resnet), netD, ndf, and so on.
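As a rough sketch of what such an options class boils down to, here is a hypothetical argparse fragment. The flag subset and the ndf default shown here are illustrative guesses, not the repo's exact definitions:

```python
import argparse

# Hypothetical sketch of a few distillation flags; the real DistillOptions
# class defines many more arguments and may use different defaults.
parser = argparse.ArgumentParser()
parser.add_argument('--distiller', type=str, default='resnet',
                    help='maps to distillers/<name>_distiller.py')
parser.add_argument('--ndf', type=int, default=64,
                    help='number of discriminator filters (illustrative default)')
opt = parser.parse_args([])
print(opt.distiller)  # resnet
```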
create_distiller is defined as:
def create_distiller(opt, verbose=True):
    distiller = find_distiller_using_name(opt.distiller)
    instance = distiller(opt)
    if verbose:
        print("distiller [%s] was created" % type(instance).__name__)
    return instance


def find_distiller_using_name(distiller_name):
    distiller_filename = "distillers." + distiller_name + '_distiller'
    modellib = importlib.import_module(distiller_filename)
    distiller = None
    target_distiller_name = distiller_name.replace('_', '') + 'distiller'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_distiller_name.lower():
            distiller = cls
    if distiller is None:
        print("In %s.py, there should be a class of Distiller with class name that matches %s in lowercase." %
              (distiller_filename, target_distiller_name))
        exit(0)
    return distiller
In the end, the instance returned by create_distiller is a model instantiated from the ResnetDistiller class.
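The naming convention that makes this lookup work can be demonstrated in isolation. This stand-in sketch mimics the matching loop of find_distiller_using_name without going through importlib:

```python
# Stand-in demonstration of the lookup convention: 'resnet' must resolve
# to a class whose lowercased name equals 'resnetdistiller'.
class ResnetDistiller:
    pass

def match_class(distiller_name, namespace):
    target = distiller_name.replace('_', '') + 'distiller'
    for name, cls in namespace.items():
        if name.lower() == target.lower():
            return cls
    return None

assert match_class('resnet', {'ResnetDistiller': ResnetDistiller}) is ResnetDistiller
```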
The ResnetDistiller class defines the distillation procedure. Its forward pass is:
def forward(self):
    with torch.no_grad():
        self.Tfake_B = self.netG_teacher(self.real_A)
    self.Sfake_B = self.netG_student(self.real_A)
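The point of the torch.no_grad() context is that the teacher acts as a frozen target: no autograd graph is built for its forward pass, so no gradients can flow back into it and no activation memory is kept. A toy sketch, with single conv layers standing in for the two generators:

```python
import torch
import torch.nn as nn

# Toy sketch: the teacher runs under torch.no_grad(), so its output is a
# constant target with no autograd history; the student stays differentiable.
netG_teacher = nn.Conv2d(3, 3, kernel_size=1)
netG_student = nn.Conv2d(3, 3, kernel_size=1)
real_A = torch.randn(1, 3, 8, 8)

with torch.no_grad():
    Tfake_B = netG_teacher(real_A)   # fixed target, no grad tracking
Sfake_B = netG_student(real_A)       # differentiable student output

print(Tfake_B.requires_grad, Sfake_B.requires_grad)  # False True
```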
And the generator's backward pass:
def backward_G(self):
    if self.opt.dataset_mode == 'aligned':
        self.loss_G_recon = self.criterionRecon(self.Sfake_B, self.real_B) * self.opt.lambda_recon
        fake = torch.cat((self.real_A, self.Sfake_B), 1)
    else:
        self.loss_G_recon = self.criterionRecon(self.Sfake_B, self.Tfake_B) * self.opt.lambda_recon
        fake = self.Sfake_B
    pred_fake = self.netD(fake)
    self.loss_G_gan = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.lambda_gan
    if self.opt.lambda_distill > 0:
        self.loss_G_distill = self.calc_distill_loss() * self.opt.lambda_distill
    else:
        self.loss_G_distill = 0
    self.loss_G = self.loss_G_gan + self.loss_G_recon + self.loss_G_distill
    self.loss_G.backward()
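Note the unaligned branch: with unpaired data (the CycleGAN case) there is no ground-truth real_B paired with real_A, so the student reconstructs the teacher's output instead. A toy sketch of that branch, assuming criterionRecon is an L1 loss (the repo makes the criterion configurable):

```python
import torch
import torch.nn.functional as F

# Unaligned case: the reconstruction target is the teacher's output Tfake_B,
# not a paired ground-truth image. L1 here is an assumed criterionRecon.
Sfake_B = torch.randn(1, 3, 8, 8, requires_grad=True)  # student output
Tfake_B = torch.randn(1, 3, 8, 8)                      # teacher output (no grad)
lambda_recon = 10.0                                    # illustrative weight

loss_G_recon = F.l1_loss(Sfake_B, Tfake_B) * lambda_recon
loss_G_recon.backward()
print(Sfake_B.grad.shape)  # torch.Size([1, 3, 8, 8])
```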
Updating the parameters:
def optimize_parameters(self):
    self.optimizer_D.zero_grad()
    self.optimizer_G.zero_grad()
    self.forward()
    self.backward_D()   # with the discriminator's requires_grad enabled
    self.backward_G()   # with the discriminator's requires_grad disabled
    self.optimizer_D.step()
    self.optimizer_G.step()
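Put together, this is the standard alternating GAN recipe: update D on the detached student output, then update G with D's parameters frozen so gradients flow through D but not into it. A generic, self-contained sketch, with tiny linear nets standing in for the real G and D and placeholder losses:

```python
import torch
import torch.nn as nn

# Generic sketch of the alternating update in optimize_parameters:
# D trains on the detached fake; G trains with D's parameters frozen.
def set_requires_grad(net, flag):
    for p in net.parameters():
        p.requires_grad = flag

netG, netD = nn.Linear(4, 4), nn.Linear(4, 1)
optG = torch.optim.Adam(netG.parameters())
optD = torch.optim.Adam(netD.parameters())

x = torch.randn(2, 4)
optD.zero_grad()
optG.zero_grad()
fake = netG(x)

set_requires_grad(netD, True)        # D step
loss_D = netD(fake.detach()).mean()  # placeholder for the real GAN loss
loss_D.backward()

set_requires_grad(netD, False)       # G step: grads flow through D, not into it
loss_G = -netD(fake).mean()
loss_G.backward()

optD.step()
optG.step()
```

Detaching fake in the D step keeps the discriminator loss from back-propagating into the generator, while freezing D in the G step avoids wasting compute on discriminator gradients that would be discarded anyway.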