GAN Compression 8: Training Process Analysis, Part 2

Previous article: 科技猛兽, "GAN Compression 7: Training Process Analysis" (zhuanlan.zhihu.com)

The previous article walked through the training of GAN Compression's teacher generator. This article continues with the distillation process, the supernet training process, and the sub-model search process.

The distillation process:

bash scripts/cycle_gan/horse2zebra/distill.sh

distill.sh runs distill.py, which again dispatches into trainer.py:

class Trainer:
    def __init__(self, task):
        if task == 'train':
            from options.train_options import TrainOptions as Options
            from models import create_model as create_model
        elif task == 'distill':
            from options.distill_options import DistillOptions as Options
            from distillers import create_distiller as create_model
        elif task == 'supernet':
            from options.supernet_options import SupernetOptions as Options
            from supernets import create_supernet as create_model
        else:
            raise NotImplementedError('Unknown task [%s]!!!' % task)
        opt = Options().parse()

The only difference from the previous article is that task == 'distill' this time, so inside trainer.py:

  • Options is actually DistillOptions.
  • create_model is actually create_distiller.
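
For orientation, distill.py itself is only a thin entry point around this dispatch. A minimal sketch of what it plausibly looks like (the start() method name is an assumption, not verified against the repo):

# Hypothetical sketch of distill.py as a thin wrapper around Trainer;
# the start() method name is an assumption, not the repo's verified API.
from trainer import Trainer

if __name__ == '__main__':
    trainer = Trainer('distill')   # selects DistillOptions + create_distiller
    trainer.start()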

Following the same approach as the previous article, jump to the definition of the DistillOptions class, which declares the command-line arguments the distillation step uses: distiller (default resnet), netD, ndf, and so on.
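
To make that concrete, here is a minimal argparse sketch of how such options could be declared. Only the 'resnet' default for --distiller comes from the text above; the flag types and the netD/ndf defaults are assumptions for illustration:

# Illustrative argparse sketch of DistillOptions-style flags; the netD and
# ndf defaults below are assumptions, not the repo's actual values.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--distiller', type=str, default='resnet',
                    help='which distiller to build (resnet -> ResnetDistiller)')
parser.add_argument('--netD', type=str, default='n_layers',
                    help='discriminator architecture (default assumed)')
parser.add_argument('--ndf', type=int, default=64,
                    help='base number of discriminator filters (default assumed)')
opt = parser.parse_args([])
print(opt.distiller)  # resnet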

create_distiller is defined as:

import importlib


def create_distiller(opt, verbose=True):
    distiller = find_distiller_using_name(opt.distiller)
    instance = distiller(opt)
    if verbose:
        print("distiller [%s] was created" % type(instance).__name__)
    return instance


def find_distiller_using_name(distiller_name):
    # e.g. 'resnet' -> module 'distillers.resnet_distiller'
    distiller_filename = "distillers." + distiller_name + '_distiller'
    modellib = importlib.import_module(distiller_filename)
    distiller = None
    # e.g. 'resnet' -> class name 'resnetdistiller' (matched case-insensitively)
    target_distiller_name = distiller_name.replace('_', '') + 'distiller'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_distiller_name.lower():
            distiller = cls

    if distiller is None:
        print("In %s.py, there should be a class of Distiller with class name that matches %s in lowercase." %
              (distiller_filename, target_distiller_name))
        exit(0)

    return distiller

In the end, the instance returned by create_distiller is a model instantiated from the ResnetDistiller class.
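
To see the name resolution concretely, here is the lookup rule from find_distiller_using_name replayed by hand for opt.distiller == 'resnet':

# Standalone replay of the lookup rule for distiller_name == 'resnet':
distiller_name = 'resnet'
distiller_filename = 'distillers.' + distiller_name + '_distiller'
target_distiller_name = distiller_name.replace('_', '') + 'distiller'

print(distiller_filename)       # distillers.resnet_distiller (module to import)
print(target_distiller_name)    # resnetdistiller (class name to match)
print('ResnetDistiller'.lower() == target_distiller_name.lower())  # True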

The ResnetDistiller class defines the distillation procedure. Its forward pass is:

    def forward(self):
        with torch.no_grad():
            self.Tfake_B = self.netG_teacher(self.real_A)
        self.Sfake_B = self.netG_student(self.real_A)
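
The torch.no_grad() context is the key detail: the teacher only produces fixed targets, so no autograd graph is built for its computation, while the student's output keeps its graph. A self-contained toy sketch of the pattern (the convolutions below are stand-ins, not the repo's generators):

# Toy reproduction of the teacher/student forward pattern; the two
# convolutions are illustrative stand-ins for the generators.
import torch
import torch.nn as nn

netG_teacher = nn.Conv2d(3, 3, kernel_size=3, padding=1)
netG_student = nn.Conv2d(3, 3, kernel_size=3, padding=1)
real_A = torch.randn(1, 3, 64, 64)

with torch.no_grad():                 # teacher output is a constant target
    Tfake_B = netG_teacher(real_A)
Sfake_B = netG_student(real_A)        # student output carries gradients

print(Tfake_B.requires_grad, Sfake_B.requires_grad)  # False True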

The backward pass for the student generator:

    def backward_G(self):
        if self.opt.dataset_mode == 'aligned':
            self.loss_G_recon = self.criterionRecon(self.Sfake_B, self.real_B) * self.opt.lambda_recon
            fake = torch.cat((self.real_A, self.Sfake_B), 1)
        else:
            self.loss_G_recon = self.criterionRecon(self.Sfake_B, self.Tfake_B) * self.opt.lambda_recon
            fake = self.Sfake_B
        pred_fake = self.netD(fake)
        self.loss_G_gan = self.criterionGAN(pred_fake, True, for_discriminator=False) * self.opt.lambda_gan
        if self.opt.lambda_distill > 0:
            self.loss_G_distill = self.calc_distill_loss() * self.opt.lambda_distill
        else:
            self.loss_G_distill = 0
        self.loss_G = self.loss_G_gan + self.loss_G_recon + self.loss_G_distill
        self.loss_G.backward()
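
Two things are worth noting here. First, the reconstruction target depends on the dataset mode: with aligned (paired) data the student reconstructs the ground truth real_B, while in the unpaired case it reconstructs the teacher's output Tfake_B instead. Second, calc_distill_loss is not shown above; in the GAN Compression paper it matches intermediate activations of teacher and student, with a learnable 1x1 convolution lifting the student's narrower channels to the teacher's width. A hedged sketch of that idea (shapes and names are illustrative, not the repo's):

# Hedged sketch of intermediate-activation distillation as described in the
# GAN Compression paper; the tensors and the 1x1 mapping are illustrative.
import torch
import torch.nn as nn
import torch.nn.functional as F

t_act = torch.randn(1, 64, 32, 32)          # teacher intermediate activation
s_act = torch.randn(1, 16, 32, 32)          # matching student activation
mapping = nn.Conv2d(16, 64, kernel_size=1)  # learnable channel-lifting 1x1 conv

loss_distill = F.mse_loss(mapping(s_act), t_act)
print(loss_distill.item())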

Updating the parameters:

    def optimize_parameters(self):
        self.optimizer_D.zero_grad()
        self.optimizer_G.zero_grad()
        self.forward()
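
The excerpt breaks off at the forward call, but the standard pix2pix/CycleGAN update pattern that this codebase follows makes the rest easy to guess: update the discriminator on a detached fake, freeze it, then update the generator, and finally step both optimizers. A self-contained toy sketch of that alternation (stand-in modules, not the repo's code):

# Toy sketch of the alternating D/G update that optimize_parameters
# presumably completes; all modules here are illustrative stand-ins.
import torch
import torch.nn as nn

G = nn.Linear(8, 8)                     # stands in for netG_student
D = nn.Linear(8, 1)                     # stands in for netD
optimizer_G = torch.optim.Adam(G.parameters())
optimizer_D = torch.optim.Adam(D.parameters())
x = torch.randn(4, 8)

optimizer_D.zero_grad()
optimizer_G.zero_grad()
fake = G(x)                             # forward()

for p in D.parameters():                # set_requires_grad(netD, True)
    p.requires_grad_(True)
loss_D = D(fake.detach()).mean()        # backward_D(): detached, G gets no grad
loss_D.backward()

for p in D.parameters():                # set_requires_grad(netD, False)
    p.requires_grad_(False)
loss_G = -D(fake).mean()                # backward_G(): grads reach G only
loss_G.backward()

optimizer_D.step()
optimizer_G.step()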