-
data_utils.py
- 数据的基本处理方法定义,由torchvision.transforms来定义返回Compose对象
- 继承Dataset类,来定义train,test,val等数据的读取和处理方式
-
- 数学公式的常规操作,矩阵运算,
- 然后写测试代码来运行验证
-
- 这部分是最好理解和编写的
- 先写好基本的res模块和upsample模块
- 然后用nn.Sequential串联各个模块
-
- 设置超参数
- 读取数据集然后用DataLoader实现batch
- 定义网络对象,统计其中参数的总数
- 定义优化器,传网络参数进去
- 训练,验证循环的编写(核心)
- 更新鉴别器:
(1)低分辨率作为噪声传入生成器得到fake_img,高分辨率作为real_img
(2)鉴别器梯度归零,两图传入鉴别器,计算D_loss并backward回传梯度,然后调用optimizerD.step()更新鉴别器的速率 - 更新生成器:
(1)生成器梯度归零,计算G_loss并backward回传梯
- 更新鉴别器:
SRGAN代码结构分析
最新推荐文章于 2024-06-03 19:42:55 发布