基于CycleGAN的风格迁移---用debug对CycleGAN源码解读

debug使用的简单介绍debug的简单使用

我们来看看这个代码构成

  • opt = TrainOptions().parse() 继承baseoption并执行parse()方法
    • BaseOptions的init()
    • BaseOptions的parse()
      • 在parse()里执行BaseOptions()的gather_options()方法,这玩意大概意思是把base,train,model,data的参数(add_argument)都整合到一块
        • TrainOption的initialize()方法
          • BaseOption的initialize()方法,添加base的parser.add_argument
          • 添加train的parser.add_argument
        • models.get_option_setter()用来追加和model有关的parser.add_argument
          • find_model_using_name() 寻找这个模型有没有定义
            • 本来这个函数也挺正常的,但是调试的时importlib.import_module(model_filename)
              在这里插入图片描述
              说下modellib = importlib.import_module(model_filename)

这个憨憨直接遍历model_filename模型文件里面所有用到的类!!!

然后把所有遍历到的模型名称输出到一个特殊变量 modellib 里

这也导致这个憨憨在使用的时候函数堆栈指针对塞进去一坨奇怪的东西,而且这玩意是在遍历项目的目录,导致step into my code在疯狂的跳新宝岛!!!

  • opt = TrainOptions().parse()
    • BaseOptions.parse()
      • BaseOptions.gather_options()**
        • get_option_setter()
          • find_model_using_name()
            • 说完modellib = importlib.import_module(model_filename)咱们继续
            • 寻找modellib里的模型名和target_model_name相等,并且是BaseModel子类的模型
            • 把这个模型给个叫model的变量并返回
            • 返回刚才找到模型的.modify_commandline_options方法 命名为model_option_setter
          • 执行刚才好不容易返回的model_option_setter ,添加model的parser.add_argument
        • get_option_setter() 和模型的套路基本一致,只不过这次实在数据集的类里找
          • find_dataset_using_name()
            • 找到数据集的类
        • dataset_option_setter,添加dataset的parser.add_argument
      • 返回一块包含所有参数的命名空间给opt
    • 打印并再设置一些GPU的参数返回
  • dataset = create_dataset(opt) 创建数据集
    • CustomDatasetDataLoader() 喜闻乐见的dataloader
      • find_dataset_using_name() ?!老兄你不对劲啊!!!
      • 返回 dataset_class 并实例化这个类 # 好的,dataset已经成为你的对象了
        • make_dataset() 这玩意返回对应目录下所有图片的路径,并组成一个list
        • get_transform() 关于图像的预处理,只不过这里封到函数里了
        • 除了__init__(),下面的__getitem__()也建议看看,那里说的是后面取图象是怎么取的
      • 并设置dataloader
    • 用load_data()方法,把CustomDatasetDataLoader()传出去
  • model = create_model(opt) 创建模型 类似数据集
    • find_model_using_name() 我就喜欢你这种让我跳过的
    • 实例化这个类 # 好的instance也变成对象了,现在…….
      • 先BaseModel.init()

      • 设置8个 loss_names

      • 设置8个 visual_names

      • 设置4个 model_names

      • self.netG_A = networks.define_G() 定义netG 的模型

        • get_norm_layer() 设置norm_layer
        • 根据设置,选择 ResnetGenerator()
          • 网络定义和forward基本都在这,对于现在只有一个res的样子可以参考这里
        • 然后执行init_net()
          • 设置GPU并使用init_weights()初始化参数
      • self.netG_B = networks.define_G() 同上

      • self.netD_A = networks.define_D() 定义netD 的模型

        • get_norm_layer() 设置norm_layer
        • 根据设置,选择 NLayerDiscriminator()
          • 网络定义和forward基本都在这,对于现在只有一个res的样子可以参考这里
        • init_net()同上
      • self.netD_B = networks.define_D() 同上

      • ImagePool() 很奇怪的东西 创建图像缓冲区以存储先前生成的图像

        • 这个缓冲区储存由netG生成的图像 并可以用历史来更新鉴别器netD, 而不是由生成器 netG直接生成
        • 注意里面还有个query()方法
      • 3个损失函数

      • criterionGAN = GANLoss() 根据设置选择的,并带一个判断是真值real还是生成值fake,分别计算

      • criterionCycle = L1Loss()

      • criterionIdt = L1Loss()

      • 2个优化器

      • optimizer_G Adam() 使用chain把 netG_A 和 netG_B 的参数 混在一起 建议结合这里观看调试器里的变量

      • optimizer_D Adam() 同上, 不过是netD_A 和 netD_B

      • 然后分别把这俩优化器都放到optimizer里

    • 打印相关信息直接返回

在这里插入图片描述

  • model.setup(opt) 加载和打印网络,并设置学习率衰减策略schedule
    • networks.get_scheduler(optimizer,opt)
    • print_networks() 字面意思
  • visualizer = Visualizer(opt) 调用visdom和创建页面数据,之前没开server的这里就该出问题了
    • 里面在创建文件夹和log文件
  • visualizer.reset() 字面意思 让self.saved = False 下次能更新
  • for i, data in enumerate(dataset) 会执行dataset的__iter__()
    • 这个data给出4*batchsize的数组
      • A A图数组
      • B B图数组
      • A_path A图对应路径
      • B_path B图对应路径
  • model.set_input(data) 设置哪面往哪面走
    • real_A
    • real_B
    • image_paths 对应网络输入图片的路径
  • model.optimize_parameters() 计算损失函数,获取梯度,更新网络权重
    • 首先是生成器的

    • self.forward() 没啥好说的
      • self.fake_B = self.netG_A(self.real_A) # G_A(A) → B
      • self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A)) → A
      • self.fake_A = self.netG_B(self.real_B) # G_B(B) → A
      • self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B)) → B
    • self.set_requires_grad([self.netD_A, self.netD_B], False) 冻结两个netD的梯度
    • self.optimizer_G.zero_grad() netG的梯度清零
    • self.backward_G() 反传计算loss
      • 这里先计算了一下
        self.idt_A = self.netG_A(self.real_B) # G_A(B) → B
        self.idt_A = self.netG_A(self.real_B) # G_B(A) → A
      • loss_idt_A 是鉴别器A的测试损失 ||G_A(B) – B||
      • loss_idt_B 是鉴别器B的测试损失 ||G_B(A) – A||
      • loss_G_A 使用criterionGAN() 把 D_A(G_A(A)) 和opt扔进去, 生成的(1,1,30,30)和全1的(1,1,30,30)求MSEloss
      • loss_G_B 使用criterionGAN() 把 D_B(G_B(B)) 和opt扔进去, 生成的(1,1,30,30)和全1的(1,1,30,30)求MSEloss
      • loss_cycle_A 使用criterionCycle 说白了就是使用L1loss,即 || G_B(G_A(A)) – A||
      • loss_cycle_B 使用criterionCycle 说白了就是使用L1loss,即 || G_A(G_B(B)) – B||
      • loss_G = 上面六个loss求和,然后反向传播
    • self.optimizer_G.step() 更新两个G的权重
  • 然后是鉴别器的

  • self.set_requires_grad([self.netD_A, self.netD_B], True) 解冻两个netD的梯度
  • self.optimizer_D.zero_grad() netD的梯度清零
  • self.backward_D_A()
    • fake_B_pool.query(fake_B)取刚才G_A(B) → B的生成图像
    • 设置 netD real fake 调用backward_D_basic(D_A, realB, fakeB)
      • netD(real) 产生 (1,1,30,30)的输出
      • loss_D_real 输出和(1,1,30,30)的全1求MSEloss
      • netD(fake) 产生 (1,1,30,30)的输出
      • loss_D_fake 输出和(1,1,30,30)的全0求MSEloss
      • loss_D = (loss_D_real + loss_D_fake) * 0.5 然后反向传播
  • self.backward_D_B()
    • fake_A_pool.query(fake_A)取刚才G_B(A) → A的生成图像
    • 设置 netD real fake 调用backward_D_basic(D_B, realA, fakeA)
      • netD(real) 产生 (1,1,30,30)的输出
      • loss_D_real 输出和(1,1,30,30)的全1求MSEloss
      • netD(fake) 产生 (1,1,30,30)的输出
      • loss_D_fake 输出和(1,1,30,30)的全0求MSEloss
      • loss_D = (loss_D_real + loss_D_fake) * 0.5 然后反向传播
  • self.optimizer_D.step() 更新两个D的权重,和G不同的一点是,DA和DB的loss大小是不一样的

后面这些,都是次要的

  • #display images on visdom and save images to a HTML file
    • visualizer.display_current_results()
  • #print training losses and save logging information to the disk
  • #cache our latest model every iterations
  • #cache our model every epochs
  • model.update_learning_rate() # 在每个epoch后根据学习率衰减策略更新学习率.

然后说一些前面可能描述不是特别详细的,PS:记住一点,D的目的是能区分出来这是生成的,G的目的是让生成的假图片和真的差不多,让D看不出来
在这里插入图片描述

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

啊菜来了

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值