LADN 的 12 个 Discriminator

网络结构

class Dis_pair(nn.Module):
    """Pair discriminator: scores whether two images form a matching pair.

    The two inputs are concatenated along the channel axis and pushed
    through a stack of stride-2 LeakyReLU conv blocks, ending in a 1x1
    conv that produces per-patch logits.
    """

    def __init__(self, input_dim_a=3, input_dim_b=3, dis_n_layer=5, norm='None', sn=True):
        super(Dis_pair, self).__init__()
        base_ch = 64
        self.model = self._make_net(base_ch, input_dim_a + input_dim_b, dis_n_layer, norm, sn)

    def _make_net(self, ch, input_dim, n_layer, norm, sn):
        # First downsampling block maps the concatenated input to `ch` channels.
        layers = [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)]
        width = ch
        # n_layer - 1 further downsampling blocks, doubling channels each time.
        for _ in range(1, n_layer):
            layers.append(LeakyReLUConv2d(width, width * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn))
            width *= 2
        # Final feature block deliberately skips normalization.
        layers.append(LeakyReLUConv2d(width, width * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn))
        width *= 2
        head = nn.Conv2d(width, 1, kernel_size=1, stride=1, padding=0)
        if sn:
            head = nn.utils.spectral_norm(head)
        layers.append(head)
        return nn.Sequential(*layers)

    def forward(self, image_a, image_b):
        paired = torch.cat((image_a, image_b), 1)
        logits = self.model(paired).view(-1)
        # Callers expect a list of outputs.
        return [logits]
class LeakyReLUConv2d(nn.Module):
    """Reflection-padded Conv2d followed by LeakyReLU.

    The convolution may be wrapped in spectral normalization (`sn=True`)
    and an InstanceNorm2d (affine=False) is inserted when
    `norm == 'Instance'`.
    """

    def __init__(self, inplanes, outplanes, kernel_size, stride, padding=0, norm='None', sn=False):
        super(LeakyReLUConv2d, self).__init__()
        # Padding is done via ReflectionPad2d, so the conv itself uses padding=0.
        conv = nn.Conv2d(inplanes, outplanes, kernel_size=kernel_size, stride=stride, padding=0, bias=True)
        if sn:
            conv = nn.utils.spectral_norm(conv)
        ops = [nn.ReflectionPad2d(padding), conv]
        if norm == 'Instance':
            ops.append(nn.InstanceNorm2d(outplanes, affine=False))
        ops.append(nn.LeakyReLU(inplace=True))
        self.model = nn.Sequential(*ops)

    def forward(self, x):
        return self.model(x)

定义网络

# Names of the 12 local facial regions. An entry with a trailing '_' is the
# mirrored (other-side) counterpart of the preceding part and shares that
# part's discriminator, so no separate network is created for it.
self.local_parts = ['eye', 'eye_', 'mouth', 'nose', 'cheek', 'cheek_', 'eyebrow', 'eyebrow_', 'uppernose', 'forehead', 'sidemouth', 'sidemouth_']
for i in range(self.n_local):
    local_part = self.local_parts[i]
    if '_' in local_part:
        continue
    # One pair-discriminator per base part, stored as e.g. self.disEye.
    setattr(self, 'dis'+local_part.capitalize(), init_net(networks.Dis_pair(opts.input_dim_a, opts.input_dim_b, opts.dis_n_layer), opts.backup_gpu, init_type='normal', gain=0.02))
# str.capitalize(): uppercases the first character and lowercases the rest

初始化网络

def init_net(net, gpu, init_type='normal', gain=0.02):
    """Move `net` onto `gpu`, initialize its weights, and return it."""
    # Fail fast when CUDA is unavailable; training requires a GPU here.
    assert torch.cuda.is_available()
    net.to(gpu)
    init_weights(net, init_type, gain)
    return net
from torch.nn import init
def init_weights(net, init_type, gain):
    """Initialize the weights of `net`'s conv/linear/batchnorm submodules in place.

    Args:
        net: module whose submodules are visited via ``net.apply``.
        init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal'.
        gain: std for 'normal' (and for BatchNorm weights), gain for
            'xavier'/'orthogonal'; unused by 'kaiming'.

    Raises:
        NotImplementedError: if `init_type` is not recognized.
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                # Bug fix: the original passed mode='fainplanes', which is not a
                # valid mode and makes kaiming_normal_ raise ValueError.
                # 'fan_in' is the intended (and default) mode.
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm: weight ~ N(1, gain), bias = 0.
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)
    print('initialize network with %s' % init_type)
    net.apply(init_func)

优化网络

# Create one Adam optimizer per (non-mirrored) local discriminator,
# stored as e.g. self.disEye_opt.
# Fix: the original paste had inconsistent indentation on the `if`/`continue`/
# `setattr` lines, which is a SyntaxError; re-indented to a consistent 4 spaces.
for i in range(self.n_local):
    local_part = self.local_parts[i]
    if '_' in local_part:
        continue
    setattr(self, 'dis'+local_part.capitalize()+'_opt', torch.optim.Adam(getattr(self, 'dis'+local_part.capitalize()).parameters(), lr=lr_dlocal_style, betas=(0.5, 0.999), weight_decay=0.0001))

训练网络

# One training step for the local-style discriminators on this batch.
model.update_D_local_style(data)

def update_D_local_style(self, data):
    """One optimization step for every local-style discriminator.

    Moves the batch onto the backup device, runs the local-style forward
    pass, then steps each part's discriminator. A part whose name carries
    a trailing '_' is the mirrored counterpart: it reuses (and further
    updates) the base part's discriminator and accumulates into the same
    loss statistic.
    """
    # Detach inputs and move them to the backup device.
    for src_key, attr in (('img_A', 'input_A'), ('img_B', 'input_B'), ('img_C', 'input_C'),
                          ('rects_A', 'rects_A'), ('rects_B', 'rects_B'), ('rects_C', 'rects_C')):
        setattr(self, attr, data[src_key].to(self.backup_device).detach())

    self.forward_local_style()

    for idx in range(self.n_local):
        part = self.local_parts[idx]
        mirrored = '_' in part
        base = part.split('_')[0] if mirrored else part
        tag = 'dis' + base.capitalize()
        disc = getattr(self, tag)
        opt = getattr(self, tag + '_opt')

        opt.zero_grad()
        if mirrored:
            loss = self.backward_local_styleD(disc, self.rects_transfer_encoded[:,idx,:], self.rects_after_encoded[:,idx,:], self.rects_blend_encoded[:,idx,:], name=base+'2', flip=True)
        else:
            loss = self.backward_local_styleD(disc, self.rects_transfer_encoded[:,idx,:], self.rects_after_encoded[:,idx,:], self.rects_blend_encoded[:,idx,:], name=base)
        nn.utils.clip_grad_norm_(disc.parameters(), 5)
        opt.step()

        if mirrored:
            # Accumulate onto the base part's running loss statistic.
            setattr(self, tag + 'Style_loss', getattr(self, tag + 'Style_loss') + loss.item())
        else:
            setattr(self, tag + 'Style_loss', loss.item())

def forward_local_style(self):
    """Run the style forward pass and keep the first half-batch of rects."""
    self.forward_style()
    half = self.batch_size // 2
    # Only the 'encoded' half of the batch feeds the local style losses.
    self.rects_transfer_encoded = self.rects_A[:half]
    self.rects_after_encoded = self.rects_B[:half]
    self.rects_blend_encoded = self.rects_C[:half]

def forward_style(self):
    """Encode the first half-batch and generate the A-to-B style transfer.

    Takes the first half of each input batch, extracts the content code
    from A and the attribute/style code from B, then synthesizes
    ``fake_B_encoded``. Tensors are moved between ``backup_device`` and
    ``device`` along the way — presumably to split memory across two
    GPUs; confirm against the trainer's device setup.
    """
    half_size = self.batch_size//2
    self.real_A_encoded = self.input_A[0:half_size]
    self.real_B_encoded = self.input_B[0:half_size]
    self.real_C_encoded = self.input_C[0:half_size]

    # get encoded z_c: content code of A. z_content_a is a 2-tuple of
    # tensors (both moved to the main device before generation).
    self.z_content_a = self.enc_c.forward_a(self.real_A_encoded)
    self.z_content_a = (self.z_content_a[0].to(self.device), self.z_content_a[1].to(self.device))

    # get encoded z_a: attribute/style code of B, computed on the backup
    # device and then moved to the main device.
    self.z_attr_b = self.enc_a.forward_b(self.real_B_encoded.to(self.backup_device))
    self.z_attr_b = self.z_attr_b.to(self.device)

    # first cross translation: A's content rendered with B's style.
    self.fake_B_encoded = self.gen.forward_b(*self.z_content_a, self.z_attr_b)

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值