SRGAN代码结构分析

  • data_utils.py

    1. 数据的基本处理方法定义,由torchvision.transforms来定义返回Compose对象
    2. 继承Dataset类,来定义train,test,val等数据的读取和处理方式
  • loss.py

    1. 数学公式的常规操作,矩阵运算,
    2. 然后写测试代码来运行验证
  • model.py

    1. 这部分是最好理解和编写的
    2. 先写好基本的res模块和upsample模块
    3. 然后用nn.Sequential串联各个模块
  • train.py

    1. 设置超参数
    2. 读取数据集然后用DataLoader实现batch
    3. 定义网络对象,统计其中参数的总数
    4. 定义优化器,传网络参数进去
    5. 训练,验证循环的编写(核心)
      • 更新鉴别器:
        (1)低分辨率作为噪声传入生成器得到fake_img,高分辨率作为real_img
        (2)鉴别器梯度归零,两图传入鉴别器,计算D_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率
      • 更新生成器:
        (1)生成器梯度归零,计算G_loss并backward回传梯
  • 2
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值