SRGAN loss部分的pytorch代码实现

转载地址:https://bbs.huaweicloud.com/forum/thread-137101-1-1.html

作者: 雨丝儿

最近在参加华为与高校合做开发mindspore模型的活动,使用mindspore开发了SRGAN模型,下面几篇帖子想针对SRGAN做一些自己的经验分享。这篇帖子分享SRAGAN loss pytorch的实现。

Pytorch版本参考:https://github.com/dongheehand/SRGAN-PyTorch

Paper中SRGAN的loss:

对于Discriminator:

就是基础GAN中Discriminator的loss

代码实现:

 

其中gt为原始高分辨率图像,lr为gt经过双三次插值缩小四倍的低分辨率图像,cross_ent为BCELoss()

对与Generator:

Generator的loss包含三部分,一是基础的MSELoss,二是adversarial loss,三是将生成的HR图像与原始高清分辨率图像分别经过预训练的vgg19提取特征后,计算MSELoss.

代码部分:

VGG_loss = perceptual_loss(vgg_net)

    cross_ent = nn.BCELoss()

    tv_loss = TVLoss()

    real_label = torch.ones((args.batch_size, 1)).to(device)

    fake_label = torch.zeros((args.batch_size, 1)).to(device)

    

        for i, tr_data in enumerate(loader):

            gt = tr_data['GT'].to(device)

            lr = tr_data['LR'].to(device)

                        

           

             output, _ = generator(lr)

            fake_prob = discriminator(output)

            

            # 第一部分

            L2_loss = l2_loss(output, gt)

             # 第二部分

            adversarial_loss = args.adv_coeff * cross_ent(fake_prob, real_label)

# 第三部分

_percep_loss, hr_feat, sr_feat = VGG_loss((gt + 1.0) / 2.0, (output + 1.0) / 2.0, layer = args.feat_layer)

percep_loss = args.vgg_rescale_coeff * _percep_loss

            g_loss = L2_loss + adversarial_loss  + percep_loss          

            g_optim.zero_grad()

            d_optim.zero_grad()

            g_loss.backward()

            g_optim.step()

          

其中vgg19是在imagenet上训练好的vgg19,选取其前37层,args.adv_coeff,args.vgg_rescale_coeff 为loss的系数,分别取0.003和0.006。

以上就是srgan loss部分的pytorch代码实现,下篇帖子将分享srgan loss部分minspore代码的实现。

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值