GAN生成对抗网络

GAN生成对抗网络

GAN 的训练可分为 3 个阶段:(1) 生成器(generator)的训练;(2) 判别器(discriminator)的训练;(3) 任务网络的训练。下面的代码中,生成器与判别器的对抗损失均采用最小二乘(LSGAN)形式,即对判别器输出与目标值 0/1 之差的平方取均值。

#####################   Update netG
                # Generator step. netD and the task network (self.model) are frozen
                # so g_loss.backward() only produces gradients for parameters that
                # still require grad (presumably netG's — not visible here, confirm).
                # NOTE(review): rgbs / real_rgbs are defined outside this excerpt;
                # presumably batches of synthetic and real RGB images — confirm.
                for p in self.netD.parameters():
                    p.requires_grad = False
                for p in self.model.parameters():
                    p.requires_grad = False
                self.netG_optimizer.zero_grad()
                
                syn2real = self.netG(rgbs)      # synthetic rgb to real rgb
                # Real images are also passed through netG.
                real2real = self.netG(real_rgbs)

                Gd_syn2real = self.netD(syn2real)
                Gd_real2real = self.netD(real2real)

                # Least-squares adversarial loss: push netD's scores on both
                # generator outputs toward the "real" target 1.0.
                Gd_loss = torch.mean((Gd_syn2real-1.0)**2) + torch.mean((Gd_real2real-1.0)**2)
                # Reconstruction term (weight 40): MSE between each generator
                # output and its own input, penalising netG for changing its input.
                renconstruction_loss = 40* (nn.MSELoss()(real2real,real_rgbs)+nn.MSELoss()(syn2real,rgbs))

                g_loss =  renconstruction_loss +  Gd_loss
                
                g_loss.backward()
                self.netG_optimizer.step()

                ################################### schedule #######################################
                # Learning-rate schedule shared by netG and netD:
                #   * iterations [0, warmup_iters): linear warmup from
                #     base_lr * warmup_ratio (at iteration 0) toward base_lr;
                #   * afterwards: linear ("poly", power 1.0) decay
                #     base_lr * (1 - total_iter_num / max_iterations).
                # The new lr is written after this iteration's optimizer step, so
                # it only takes effect from the next step().
                # NOTE(review): lr_ becomes negative if total_iter_num exceeds
                # max_iterations — confirm the outer loop stops before that.
                warmup_iters = 500
                warmup_ratio = 1e-06
                if self.total_iter_num < warmup_iters:
                    k = (1 - self.total_iter_num / warmup_iters) * (1 - warmup_ratio)
                    lr_ = self.base_lr * (1 - k)
                else:
                    lr_ = self.base_lr * (1.0 - self.total_iter_num / self.max_iterations) ** 1.0
                for param_group in self.netG_optimizer.param_groups:
                    param_group['lr'] = lr_
                #####################   Update netD
                # Discriminator step: unfreeze netD. The task network stays frozen;
                # this repeats the freeze done before the netG step.
                for p in self.netD.parameters():
                    p.requires_grad = True
                for p in self.model.parameters():
                    p.requires_grad = False
                self.netD_optimizer.zero_grad()

                # Generator outputs are detached so d_loss does not backpropagate
                # into netG.
                d_syn2real = self.netD(syn2real.detach())
                d_real2real = self.netD(real2real.detach())

                # Least-squares discriminator loss: target 0 ("fake") for outputs
                # generated from synthetic images, target 1 ("real") for outputs
                # generated from real images.
                d_loss = torch.mean((d_syn2real-0.)**2) + torch.mean((d_real2real-1.0)**2)
                d_loss.backward()
                self.netD_optimizer.step()

                # netD uses the same lr_ as netG (computed above).
                for param_group in self.netD_optimizer.param_groups:
                    param_group['lr'] = lr_
                
                # Every 10 iterations, log the first sample of each batch and the
                # current losses via self.writer (presumably a TensorBoard
                # SummaryWriter — confirm).
                # NOTE(review): the 'syn2real_loss' tag records
                # renconstruction_loss, which includes BOTH the syn2real and the
                # real2real reconstruction terms (x40).
                if self.total_iter_num%10==0:
                    self.writer.add_image('Input/sim_rgb',rgbs[0].detach().cpu(),self.total_iter_num)
                    self.writer.add_image('Input/real_rgb',real_rgbs[0].detach().cpu(),self.total_iter_num)
                    self.writer.add_image('output/syn2real',syn2real[0].detach().cpu(),self.total_iter_num)   
                    self.writer.add_image('output/real2real',real2real[0].detach().cpu(),self.total_iter_num) 

                    self.writer.add_scalar('syn2real_loss', renconstruction_loss,self.total_iter_num)
                    self.writer.add_scalar('Gd_loss', Gd_loss,self.total_iter_num)

                    self.writer.add_scalar('GAN_g_loss', g_loss,self.total_iter_num)
                    self.writer.add_scalar('GAN_d_loss', d_loss,self.total_iter_num)  

                #####################   Update netT
                ##################### task code
                # (The task-network update is not included in this excerpt.)

GD_net.py

import torchvision.models as models
import torch.nn as nn
import torch
import torch.nn.functional as torch_nn_func
import math

from collections import namedtuple
class atrous_conv(nn.Sequential):
    """Pre-activation dilated-convolution unit used by the decoder's DASPP branches.

    Layout (all inside ``self.atrous_conv``):
        [BatchNorm2d(in_channels)]             -- only when ``apply_bn_first``
        ReLU -> 1x1 Conv(in -> 2*out) -> BatchNorm2d(2*out) -> ReLU
             -> 3x3 Conv(2*out -> out, dilation=d, padding=d)

    With stride 1 and ``padding == dilation`` the 3x3 dilated convolution
    preserves the spatial size of its input.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        dilation: dilation (and padding) of the 3x3 convolution.
        apply_bn_first: prepend a BatchNorm2d over the input channels.
    """

    def __init__(self, in_channels, out_channels, dilation, apply_bn_first=True):
        super().__init__()
        # Submodule names and creation order are unchanged so existing
        # checkpoints (keys like 'atrous_conv.first_bn.weight') still load and
        # parameter initialisation consumes the RNG in the same order.
        self.atrous_conv = torch.nn.Sequential()
        if apply_bn_first:
            self.atrous_conv.add_module('first_bn', nn.BatchNorm2d(in_channels, momentum=0.01, affine=True, track_running_stats=True, eps=1.1e-5))

        self.atrous_conv.add_module('aconv_sequence', nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels * 2, bias=False, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(out_channels * 2, momentum=0.01, affine=True, track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=out_channels * 2, out_channels=out_channels, bias=False, kernel_size=3, stride=1,
                      padding=(dilation, dilation), dilation=dilation),
        ))

    def forward(self, x):
        """Apply the unit to ``x`` and return the ``out_channels`` feature map."""
        # Call the submodule rather than its ``.forward`` so that hooks
        # registered on ``self.atrous_conv`` are honoured.
        return self.atrous_conv(x)
    

class upconv(nn.Module):
    """Upsampling block: nearest-neighbour resize, then 3x3 conv and ELU.

    Args:
        in_channels: channel count of the incoming feature map.
        out_channels: channel count produced by the 3x3 convolution.
        ratio: spatial scale factor handed to ``interpolate`` (default 2).
    """

    def __init__(self, in_channels, out_channels, ratio=2):
        super().__init__()
        self.elu = nn.ELU()
        # Bias-free 3x3 conv; padding=1 keeps the upsampled spatial size.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
        )
        self.ratio = ratio

    def forward(self, x):
        """Upsample ``x`` by ``self.ratio`` and apply conv + ELU."""
        enlarged = torch_nn_func.interpolate(x, scale_factor=self.ratio, mode='nearest')
        return self.elu(self.conv(enlarged))

class Decoder(nn.Module):
    def __init__(self, feat_out_channels, num_features=512):
        """Build the decoder's upsampling stages, DASPP branches and output heads.

        Args:
            feat_out_channels: per-level encoder channel counts; indices 0-4
                are read here. Index 4 feeds ``upconv5``; indices 3, 2, 1, 0
                are concatenated as skip connections at stages 5, 4, 3, 2
                respectively. Presumably ordered shallow-to-deep — confirm
                against the encoder that produces the features.
            num_features: base width. Stages use num_features, //2, //4, //8
                and //16 channels (512/256/128/64/32 with the default).

        Note:
            Submodules are created in a fixed order. Reordering these
            statements changes both the RNG draws used for parameter
            initialisation and the order of registered submodules.
        """
        super(Decoder, self).__init__()

        # Stage 5: upsample the deepest features, then fuse with skip
        # feat_out_channels[3] (concatenation) via conv5.
        self.upconv5    = upconv(feat_out_channels[4], num_features)
        self.bn5        = nn.BatchNorm2d(num_features, momentum=0.01, affine=True, eps=1.1e-5)
        
        self.conv5      = torch.nn.Sequential(nn.Conv2d(num_features + feat_out_channels[3], num_features, 3, 1, 1, bias=False),
                                              nn.ELU())
        # Stage 4: upsample, fuse with skip feat_out_channels[2]; bn4_2
        # normalises the fused result (iconv4) before the DASPP branches.
        self.upconv4    = upconv(num_features, num_features // 2)
        self.bn4        = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
        self.conv4      = torch.nn.Sequential(nn.Conv2d(num_features // 2 + feat_out_channels[2], num_features // 2, 3, 1, 1, bias=False),
                                              nn.ELU())
        self.bn4_2      = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
        
        # Densely connected atrous branches (dilations 3, 6, 12, 18, 24), each
        # producing num_features // 4 channels. daspp_3 takes iconv4 alone;
        # daspp_6 takes concat4 (num_features // 2 + feat_out_channels[2])
        # plus daspp_3's output, so in-channels grow by num_features // 4 per
        # branch. The in-channels of daspp_12/18/24 are consistent with
        # continuing this running concatenation — confirm against the rest of
        # forward(), which is not visible here.
        self.daspp_3    = atrous_conv(num_features // 2, num_features // 4, 3, apply_bn_first=False)
        self.daspp_6    = atrous_conv(num_features // 2 + num_features // 4 + feat_out_channels[2], num_features // 4, 6)
        self.daspp_12   = atrous_conv(num_features + feat_out_channels[2], num_features // 4, 12)
        self.daspp_18   = atrous_conv(num_features + num_features // 4 + feat_out_channels[2], num_features // 4, 18)
        self.daspp_24   = atrous_conv(num_features + num_features // 2 + feat_out_channels[2], num_features // 4, 24)
        # In-channels num_features + num_features//2 + num_features//4 equal
        # iconv4 (num_features // 2) plus five num_features // 4 branch
        # outputs — presumably that concatenation; confirm in forward().
        self.daspp_conv = torch.nn.Sequential(nn.Conv2d(num_features + num_features // 2 + num_features // 4, num_features // 4, 3, 1, 1, bias=False),
                                              nn.ELU())
        
        # Stage 3: upsample without reducing channels; fuse with skip
        # feat_out_channels[1].
        self.upconv3    = upconv(num_features // 4, num_features // 4)
        self.bn3        = nn.BatchNorm2d(num_features // 4, momentum=0.01, affine=True, eps=1.1e-5)
        self.conv3      = torch.nn.Sequential(nn.Conv2d(num_features // 4 + feat_out_channels[1], num_features // 4, 3, 1, 1, bias=False),
                                              nn.ELU())
        
        # Stage 2: upsample; fuse with skip feat_out_channels[0].
        self.upconv2    = upconv(num_features // 4, num_features // 8)
        self.bn2        = nn.BatchNorm2d(num_features // 8, momentum=0.01, affine=True, eps=1.1e-5)
        self.conv2      = torch.nn.Sequential(nn.Conv2d(num_features // 8 + feat_out_channels[0], num_features // 8, 3, 1, 1, bias=False),
                                              nn.ELU())
        
        # Stage 1: upsample; unlike the other stages, conv1 takes no skip input.
        self.upconv1    = upconv(num_features // 8, num_features // 16)
        self.bn1        = nn.BatchNorm2d(num_features // 16, momentum=0.01, affine=True, eps=1.1e-5)
        self.conv1      = torch.nn.Sequential(nn.Conv2d(num_features // 16, num_features // 16, 3, 1, 1, bias=False),
                                              nn.ELU())
        # Two 3-channel heads (3x3 conv + Sigmoid, so values lie in (0, 1)).
        # NOTE(review): how conv_out vs conv_p are used is not visible in
        # this excerpt — confirm in forward().
        self.conv_out = torch.nn.Sequential(nn.Conv2d(num_features // 16, 3, 3, 1, 1, bias=False),
                                              nn.Sigmoid())
        self.conv_p  =  torch.nn.Sequential(nn.Conv2d(num_features // 16, 3, 3, 1, 1, bias=False),
                                              nn.Sigmoid())                                    
            

    def forward(self, features):
        skip0, skip1, skip2, skip3 = features[0], features[1], features[2], features[3]
        dense_features = torch.nn.ReLU()(features[4])
        upconv5 = self.upconv5(dense_features) # H/16
        upconv5 = self.bn5(upconv5)
        if upconv5.shape[2:]!=skip3.shape[2:]:
            upconv5 = torch_nn_func.interpolate(upconv5, skip3.shape[2:], mode='nearest')
        concat5 = torch.cat([upconv5, skip3], dim=1)
        iconv5 = self.conv5(concat5)
        
        upconv4 = self.upconv4(iconv5) # H/8
        upconv4 = self.bn4(upconv4)
        if upconv4.shape[2:]!=skip2.shape[2:]:
            upconv4 = torch_nn_func.interpolate(upconv4, skip2.shape[2:], mode='nearest')
        concat4 = torch.cat([upconv4, skip2], dim=1)
        iconv4 = self.conv4(concat4)
        iconv4 = self.bn4_2(iconv4)

        daspp_3 = self.daspp_3(iconv4)
        concat4_2 = torch.cat([concat4, daspp_3], dim=1)
        daspp_6 = self.daspp_6(concat4_2)
        concat4_3 = torch.cat([concat4_2, daspp_6], dim=1)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值