网络结构 (Network architecture) — the pair discriminator and its convolutional building block:
class Dis_pair(nn.Module):
    """Pair discriminator: scores the channel-wise concatenation of two images.

    A stack of stride-2 LeakyReLU conv blocks followed by a 1x1 conv head
    producing a single-channel map, flattened to a 1-D score vector.
    """

    def __init__(self, input_dim_a=3, input_dim_b=3, dis_n_layer=5, norm='None', sn=True):
        super(Dis_pair, self).__init__()
        base_ch = 64
        self.model = self._make_net(base_ch, input_dim_a + input_dim_b, dis_n_layer, norm, sn)

    def _make_net(self, ch, input_dim, n_layer, norm, sn):
        """Build the downsampling trunk plus the 1x1 scoring head."""
        layers = [LeakyReLUConv2d(input_dim, ch, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn)]
        width = ch
        # Each subsequent block halves the spatial size and doubles the width.
        for _ in range(n_layer - 1):
            layers.append(LeakyReLUConv2d(width, width * 2, kernel_size=3, stride=2, padding=1, norm=norm, sn=sn))
            width *= 2
        # Final conv block deliberately skips normalization (norm='None').
        layers.append(LeakyReLUConv2d(width, width * 2, kernel_size=3, stride=2, padding=1, norm='None', sn=sn))
        width *= 2
        head = nn.Conv2d(width, 1, kernel_size=1, stride=1, padding=0)
        if sn:
            head = nn.utils.spectral_norm(head)
        layers.append(head)
        return nn.Sequential(*layers)

    def forward(self, image_a, image_b):
        """Concatenate the pair along channels and return [flattened scores]."""
        pair = torch.cat((image_a, image_b), 1)
        score = self.model(pair).view(-1)
        return [score]
class LeakyReLUConv2d(nn.Module):
def __init__(self, inplanes, outplanes, kernel_size, stride, padding=0, norm='None', sn=False):
super(LeakyReLUConv2d, self).__init__()
model = []
model += [nn.ReflectionPad2d(padding)]
if sn:
model += [nn.utils.spectral_norm(nn.Conv2d(inplanes, outplanes, kernel_size=kernel_size, stride=stride, padding=0, bias=True))]
else:
model += [nn.Conv2d(inplanes, outplanes, kernel_size=kernel_size, stride=stride, padding=0, bias=True)]
if norm == 'Instance':
model += [nn.InstanceNorm2d(outplanes, affine=False)]
model += [nn.LeakyReLU(inplace=True)]
self.model = nn.Sequential(*model)
def forward(self, x):
return self.model(x)
定义网络 (Defining the networks) — one pair discriminator is created per base local facial part:
# Local facial parts; a trailing '_' marks the mirrored twin of the part
# before it (e.g. 'eye_' is the other eye) — twins reuse the base part's
# discriminator instead of getting their own.
self.local_parts = ['eye', 'eye_', 'mouth', 'nose', 'cheek', 'cheek_', 'eyebrow', 'eyebrow_', 'uppernose', 'forehead', 'sidemouth', 'sidemouth_']
for i in range(self.n_local):
    local_part = self.local_parts[i]
    if '_' in local_part:
        # Mirrored twin: no separate discriminator is instantiated.
        continue
    # Creates e.g. self.disEye = init_net(Dis_pair(...), ...) on the backup GPU.
    setattr(self, 'dis'+local_part.capitalize(), init_net(networks.Dis_pair(opts.input_dim_a, opts.input_dim_b, opts.dis_n_layer), opts.backup_gpu, init_type='normal', gain=0.02))
    # str.capitalize(): uppercases the first character and lowercases the rest.
初始化网络 (Initializing the networks) — moving the net to the GPU and initializing its weights:
def init_net(net, gpu, init_type='normal', gain=0.02):
    """Move *net* onto *gpu*, initialize its weights in place, and return it.

    Fails fast (AssertionError) when CUDA is unavailable.
    """
    assert torch.cuda.is_available()
    net.to(gpu)
    init_weights(net, init_type, gain)
    return net
from torch.nn import init
def init_weights(net, init_type, gain):
    """Initialize all Conv/Linear/BatchNorm2d weights of *net* in place.

    Args:
        net: an nn.Module; initialization is applied recursively via net.apply.
        init_type: one of 'normal', 'xavier', 'kaiming', 'orthogonal'.
        gain: std for 'normal' / BatchNorm weights, gain for xavier/orthogonal.

    Raises:
        NotImplementedError: for an unrecognized init_type (only when a
            Conv/Linear module is actually visited).
    """
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type == 'normal':
                init.normal_(m.weight.data, 0.0, gain)
            elif init_type == 'xavier':
                init.xavier_normal_(m.weight.data, gain=gain)
            elif init_type == 'kaiming':
                # BUG FIX: mode was the typo 'fainplanes', which makes
                # kaiming_normal_ raise ValueError; valid modes are
                # 'fan_in' / 'fan_out'.
                init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                init.orthogonal_(m.weight.data, gain=gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm: weight ~ N(1, gain), bias = 0.
            init.normal_(m.weight.data, 1.0, gain)
            init.constant_(m.bias.data, 0.0)
    print('initialize network with %s' % init_type)
    net.apply(init_func)
优化网络 (Optimizers) — one Adam optimizer per base-part discriminator:
# Create one Adam optimizer per base local part; mirrored parts ('xxx_')
# share the base part's discriminator and therefore need no optimizer of
# their own.
for i in range(self.n_local):
    local_part = self.local_parts[i]
    if '_' in local_part:
        continue
    # e.g. self.disEye_opt = Adam(self.disEye.parameters(), ...)
    setattr(self, 'dis'+local_part.capitalize()+'_opt', torch.optim.Adam(getattr(self, 'dis'+local_part.capitalize()).parameters(), lr=lr_dlocal_style, betas=(0.5, 0.999), weight_decay=0.0001))
训练网络 (Training) — one discriminator update step per batch:
# Training-loop call: one local-style-discriminator update on batch *data*.
model.update_D_local_style(data)
def update_D_local_style(self, data):
    """One optimization step for every local-part style discriminator.

    data: dict with 'img_A'/'img_B'/'img_C' image tensors and
    'rects_A'/'rects_B'/'rects_C' per-part tensors (exact shapes not
    visible here — TODO confirm against the dataloader; the slicing
    below uses [:, i, :]).
    """
    # Stage the batch on the backup device; detach so the D update cannot
    # backpropagate into whatever produced the inputs.
    self.input_A = data['img_A'].to(self.backup_device).detach()
    self.input_B = data['img_B'].to(self.backup_device).detach()
    self.input_C = data['img_C'].to(self.backup_device).detach()
    self.rects_A = data['rects_A'].to(self.backup_device).detach()
    self.rects_B = data['rects_B'].to(self.backup_device).detach()
    self.rects_C = data['rects_C'].to(self.backup_device).detach()
    self.forward_local_style()
    for i in range(self.n_local):
        local_part = self.local_parts[i]
        if '_' not in local_part:
            # Base part: its own discriminator takes one optimizer step.
            getattr(self, 'dis'+local_part.capitalize()+'_opt').zero_grad()
            # NOTE(review): backward_local_styleD presumably calls
            # .backward() internally — confirm; nothing else here would
            # populate the gradients clipped/stepped below.
            loss_D_Style = self.backward_local_styleD(getattr(self, 'dis'+local_part.capitalize()), self.rects_transfer_encoded[:,i,:], self.rects_after_encoded[:,i,:], self.rects_blend_encoded[:,i,:], name=local_part)
            nn.utils.clip_grad_norm_(getattr(self, 'dis'+local_part.capitalize()).parameters(), 5)
            getattr(self, 'dis'+local_part.capitalize()+'_opt').step()
            # Log the scalar loss as e.g. self.disEyeStyle_loss.
            setattr(self, 'dis'+local_part.capitalize()+'Style_loss', loss_D_Style.item())
        else:
            # Mirrored twin ('xxx_'): reuse the base part's discriminator and
            # optimizer, feed the twin's rect with flip=True, and take a
            # second optimizer step.
            local_part = local_part.split('_')[0]
            getattr(self, 'dis'+local_part.capitalize()+'_opt').zero_grad()
            loss_D_Style_ = self.backward_local_styleD(getattr(self, 'dis'+local_part.capitalize()), self.rects_transfer_encoded[:,i,:], self.rects_after_encoded[:,i,:], self.rects_blend_encoded[:,i,:], name=local_part+'2', flip=True)
            nn.utils.clip_grad_norm_(getattr(self, 'dis'+local_part.capitalize()).parameters(), 5)
            getattr(self, 'dis'+local_part.capitalize()+'_opt').step()
            # Accumulate the twin's loss onto the base part's logged loss
            # (relies on the base part having been processed earlier in the
            # list, which sets the attribute read here).
            loss_D_Style = getattr(self, 'dis'+local_part.capitalize()+'Style_loss')
            setattr(self, 'dis'+local_part.capitalize()+'Style_loss', loss_D_Style+loss_D_Style_.item())
def forward_local_style(self):
    """Run the shared style forward pass, then keep only the first
    (encoded) half of each per-part rect batch."""
    self.forward_style()
    n_enc = self.batch_size // 2
    self.rects_transfer_encoded = self.rects_A[:n_enc]
    self.rects_after_encoded = self.rects_B[:n_enc]
    self.rects_blend_encoded = self.rects_C[:n_enc]
def forward_style(self):
    """Encode the first half of the batch and synthesize the cross-domain fake.

    Only the first batch_size//2 samples are used here; the second half is
    presumably reserved for another pass — TODO confirm against the caller.
    """
    half_size = self.batch_size//2
    self.real_A_encoded = self.input_A[0:half_size]
    self.real_B_encoded = self.input_B[0:half_size]
    self.real_C_encoded = self.input_C[0:half_size]
    # get encoded z_c (content code from A). The encoder output is indexed
    # [0]/[1] and later unpacked with *, so two tensors are used; both are
    # moved from the backup device onto the main device for generation.
    self.z_content_a = self.enc_c.forward_a(self.real_A_encoded)
    self.z_content_a = (self.z_content_a[0].to(self.device), self.z_content_a[1].to(self.device))
    # get encoded z_a (attribute/style code from B), also moved to the main
    # device after encoding on the backup device.
    self.z_attr_b = self.enc_a.forward_b(self.real_B_encoded.to(self.backup_device))
    self.z_attr_b = self.z_attr_b.to(self.device)
    # first cross translation: content of A rendered with the attribute of B.
    self.fake_B_encoded = self.gen.forward_b(*self.z_content_a, self.z_attr_b)