GAN (Generative Adversarial Network)
GAN training is divided into 3 stages: (1) training the generator, (2) training the discriminator, (3) training the task network.
##################### Update netG #####################
# Stage 1: train the generator. Freeze the discriminator and the task
# network so only netG receives gradients from this step.
for p in self.netD.parameters():
    p.requires_grad = False
for p in self.model.parameters():
    p.requires_grad = False

self.netG_optimizer.zero_grad()
syn2real = self.netG(rgbs)        # translate synthetic RGB -> real-style RGB
real2real = self.netG(real_rgbs)  # identity pass on real RGB

# LSGAN adversarial loss: the generator tries to push the discriminator
# output toward 1.0 ("real") for both translated and identity images.
Gd_syn2real = self.netD(syn2real)
Gd_real2real = self.netD(real2real)
Gd_loss = torch.mean((Gd_syn2real - 1.0) ** 2) + torch.mean((Gd_real2real - 1.0) ** 2)

# Reconstruction term keeps the translation close to its input; the factor
# 40 weights it against the adversarial term. (Fixed typo: variable was
# previously spelled "renconstruction_loss"; also build MSELoss once
# instead of twice.)
mse = nn.MSELoss()
reconstruction_loss = 40 * (mse(real2real, real_rgbs) + mse(syn2real, rgbs))
g_loss = reconstruction_loss + Gd_loss
g_loss.backward()
self.netG_optimizer.step()

################################### schedule #######################################
# Linear warm-up over the first `warmup_iters` iterations, then linear
# decay from base_lr down to 0 across the remaining iterations.
warmup_iters = 500
warmup_ratio = 1e-06
if self.total_iter_num < warmup_iters:
    k = (1 - self.total_iter_num / warmup_iters) * (1 - warmup_ratio)
    lr_ = self.base_lr * (1 - k)
else:
    lr_ = self.base_lr * (1.0 - self.total_iter_num / self.max_iterations) ** 1.0
for param_group in self.netG_optimizer.param_groups:
    param_group['lr'] = lr_

##################### Update netD #####################
# Stage 2: train the discriminator. Unfreeze netD; generator outputs are
# detached so no gradient flows back into netG.
for p in self.netD.parameters():
    p.requires_grad = True
for p in self.model.parameters():
    p.requires_grad = False

self.netD_optimizer.zero_grad()
d_syn2real = self.netD(syn2real.detach())
d_real2real = self.netD(real2real.detach())
# LSGAN discriminator targets: 0 for translated ("fake"), 1 for real.
d_loss = torch.mean((d_syn2real - 0.) ** 2) + torch.mean((d_real2real - 1.0) ** 2)
d_loss.backward()
self.netD_optimizer.step()
# Discriminator follows the same learning-rate schedule as the generator.
for param_group in self.netD_optimizer.param_groups:
    param_group['lr'] = lr_

# Periodic TensorBoard logging of inputs, translations and losses.
if self.total_iter_num % 10 == 0:
    self.writer.add_image('Input/sim_rgb', rgbs[0].detach().cpu(), self.total_iter_num)
    self.writer.add_image('Input/real_rgb', real_rgbs[0].detach().cpu(), self.total_iter_num)
    self.writer.add_image('output/syn2real', syn2real[0].detach().cpu(), self.total_iter_num)
    self.writer.add_image('output/real2real', real2real[0].detach().cpu(), self.total_iter_num)
    self.writer.add_scalar('syn2real_loss', reconstruction_loss, self.total_iter_num)
    self.writer.add_scalar('Gd_loss', Gd_loss, self.total_iter_num)
    self.writer.add_scalar('GAN_g_loss', g_loss, self.total_iter_num)
    self.writer.add_scalar('GAN_d_loss', d_loss, self.total_iter_num)

##################### Update netT #####################
##################### task code #####################
# ===== GD_net.py =====
import torchvision.models as models
import torch.nn as nn
import torch
import torch.nn.functional as torch_nn_func
import math
from collections import namedtuple
class atrous_conv(nn.Sequential):
    """Dilated 3x3 convolution block: (BN ->) ReLU -> 1x1 conv (2x channel
    expansion) -> BN -> ReLU -> 3x3 dilated conv.

    Padding equals the dilation rate, so spatial size is preserved.
    Submodule names ('first_bn', 'aconv_sequence') match the original
    layout so state_dicts stay compatible.
    """

    def __init__(self, in_channels, out_channels, dilation, apply_bn_first=True):
        super(atrous_conv, self).__init__()
        self.atrous_conv = torch.nn.Sequential()

        # Optional leading BatchNorm (skipped e.g. for a first branch that
        # already receives normalized input).
        if apply_bn_first:
            self.atrous_conv.add_module(
                'first_bn',
                nn.BatchNorm2d(in_channels, momentum=0.01, affine=True,
                               track_running_stats=True, eps=1.1e-5))

        hidden = out_channels * 2  # bottleneck expansion width
        body = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(in_channels=in_channels, out_channels=hidden,
                      bias=False, kernel_size=1, stride=1, padding=0),
            nn.BatchNorm2d(hidden, momentum=0.01, affine=True,
                           track_running_stats=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=hidden, out_channels=out_channels,
                      bias=False, kernel_size=3, stride=1,
                      padding=(dilation, dilation), dilation=dilation),
        )
        self.atrous_conv.add_module('aconv_sequence', body)

    def forward(self, x):
        # Delegate to the inner sequential container (kept as .forward to
        # mirror the original call path exactly).
        return self.atrous_conv.forward(x)
class upconv(nn.Module):
    """Nearest-neighbour upsampling by `ratio`, then a 3x3 conv and ELU."""

    def __init__(self, in_channels, out_channels, ratio=2):
        super(upconv, self).__init__()
        self.elu = nn.ELU()
        self.conv = nn.Conv2d(in_channels=in_channels,
                              out_channels=out_channels,
                              bias=False, kernel_size=3, stride=1, padding=1)
        self.ratio = ratio

    def forward(self, x):
        # Upsample first so the convolution operates at the target resolution.
        upsampled = torch_nn_func.interpolate(x, scale_factor=self.ratio,
                                              mode='nearest')
        return self.elu(self.conv(upsampled))
class Decoder(nn.Module):
def __init__(self, feat_out_channels, num_features=512):
    """Decoder with encoder skip connections and a dense ASPP (DASPP) stage.

    Args:
        feat_out_channels: per-level channel counts of the encoder features,
            indexed 0 (shallowest) to 4 (deepest), used for skip fusion.
        num_features: base decoder width; halved at each upsampling stage.
    """
    super(Decoder, self).__init__()
    # Stage 5 (H/32 -> H/16): upsample deepest features, fuse with skip 3.
    self.upconv5 = upconv(feat_out_channels[4], num_features)
    self.bn5 = nn.BatchNorm2d(num_features, momentum=0.01, affine=True, eps=1.1e-5)
    self.conv5 = torch.nn.Sequential(nn.Conv2d(num_features + feat_out_channels[3], num_features, 3, 1, 1, bias=False),
                                     nn.ELU())
    # Stage 4 (H/16 -> H/8): upsample, fuse with skip 2.
    self.upconv4 = upconv(num_features, num_features // 2)
    self.bn4 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
    self.conv4 = torch.nn.Sequential(nn.Conv2d(num_features // 2 + feat_out_channels[2], num_features // 2, 3, 1, 1, bias=False),
                                     nn.ELU())
    self.bn4_2 = nn.BatchNorm2d(num_features // 2, momentum=0.01, affine=True, eps=1.1e-5)
    # Dense ASPP: each branch's input is the concatenation of the stage-4
    # features and all previous branch outputs, so input widths grow by
    # num_features // 4 per branch (dilations 3, 6, 12, 18, 24).
    # The first branch skips its leading BN (apply_bn_first=False) since
    # its input was just normalized by bn4_2.
    self.daspp_3 = atrous_conv(num_features // 2, num_features // 4, 3, apply_bn_first=False)
    self.daspp_6 = atrous_conv(num_features // 2 + num_features // 4 + feat_out_channels[2], num_features // 4, 6)
    self.daspp_12 = atrous_conv(num_features + feat_out_channels[2], num_features // 4, 12)
    self.daspp_18 = atrous_conv(num_features + num_features // 4 + feat_out_channels[2], num_features // 4, 18)
    self.daspp_24 = atrous_conv(num_features + num_features // 2 + feat_out_channels[2], num_features // 4, 24)
    # Project the concatenated DASPP outputs back down to num_features // 4.
    self.daspp_conv = torch.nn.Sequential(nn.Conv2d(num_features + num_features // 2 + num_features // 4, num_features // 4, 3, 1, 1, bias=False),
                                          nn.ELU())
    # Stage 3 (H/8 -> H/4): fuse with skip 1.
    self.upconv3 = upconv(num_features // 4, num_features // 4)
    self.bn3 = nn.BatchNorm2d(num_features // 4, momentum=0.01, affine=True, eps=1.1e-5)
    self.conv3 = torch.nn.Sequential(nn.Conv2d(num_features // 4 + feat_out_channels[1], num_features // 4, 3, 1, 1, bias=False),
                                     nn.ELU())
    # Stage 2 (H/4 -> H/2): fuse with skip 0.
    self.upconv2 = upconv(num_features // 4, num_features // 8)
    self.bn2 = nn.BatchNorm2d(num_features // 8, momentum=0.01, affine=True, eps=1.1e-5)
    self.conv2 = torch.nn.Sequential(nn.Conv2d(num_features // 8 + feat_out_channels[0], num_features // 8, 3, 1, 1, bias=False),
                                     nn.ELU())
    # Stage 1 (H/2 -> H): final upsampling, no skip connection.
    self.upconv1 = upconv(num_features // 8, num_features // 16)
    self.bn1 = nn.BatchNorm2d(num_features // 16, momentum=0.01, affine=True, eps=1.1e-5)
    self.conv1 = torch.nn.Sequential(nn.Conv2d(num_features // 16, num_features // 16, 3, 1, 1, bias=False),
                                     nn.ELU())
    # Two parallel 3-channel sigmoid heads; presumably the translated RGB
    # output and an auxiliary prediction — TODO(review): confirm against
    # the forward() usage (not fully visible here).
    self.conv_out = torch.nn.Sequential(nn.Conv2d(num_features // 16, 3, 3, 1, 1, bias=False),
                                        nn.Sigmoid())
    self.conv_p = torch.nn.Sequential(nn.Conv2d(num_features // 16, 3, 3, 1, 1, bias=False),
                                      nn.Sigmoid())
def forward(self, features):
skip0, skip1, skip2, skip3 = features[0], features[1], features[2], features[3]
dense_features = torch.nn.ReLU()(features[4])
upconv5 = self.upconv5(dense_features) # H/16
upconv5 = self.bn5(upconv5)
if upconv5.shape[2:]!=skip3.shape[2:]:
upconv5 = torch_nn_func.interpolate(upconv5, skip3.shape[2:], mode='nearest')
concat5 = torch.cat([upconv5, skip3], dim=1)
iconv5 = self.conv5(concat5)
upconv4 = self.upconv4(iconv5) # H/8
upconv4 = self.bn4(upconv4)
if upconv4.shape[2:]!=skip2.shape[2:]:
upconv4 = torch_nn_func.interpolate(upconv4, skip2.shape[2:], mode='nearest')
concat4 = torch.cat([upconv4, skip2], dim=1)
iconv4 = self.conv4(concat4)
iconv4 = self.bn4_2(iconv4)
daspp_3 = self.daspp_3(iconv4)
concat4_2 = torch.cat([concat4, daspp_3], dim=1)
daspp_6 = self.daspp_6(concat4_2)
concat4_3 = torch.cat([concat4_2, daspp_6], dim=1)