《Guided Image-to-Image Translation with Bi-Directional Feature Transformation》
train.py
# One training iteration (excerpt from train.py). `model` and `data` come from the
# surrounding code, which is not shown — presumably `data` is the current batch of the loader loop.
model.set_input(data)  # store the batch tensors 'A', 'B' and 'guide' on the model's device
model.optimize_parameters()  # presumably forward + loss + optimizer step(s); body not shown here
开始训练：每次迭代先用set_input载入一个batch的数据，再调用optimize_parameters优化参数
models/guided_pix2pix_model.py
def set_input(self, input):
    """Move one batch onto ``self.device`` and keep it on the model.

    ``input`` is a mapping with keys 'A' (the input image), 'B' (the ground
    truth) and 'guide' (the guiding image/pose); each value must support
    ``.to(device)``.
    """
    # (model attribute, batch key) pairs, assigned in this exact order
    for attr_name, batch_key in (('real_A', 'A'), ('real_B', 'B'), ('guide', 'guide')):
        setattr(self, attr_name, input[batch_key].to(self.device))
这里的guide指的是引导用的image/pose，real_A指的是作为输入的image，real_B指的是GT（ground truth，即目标图像）
def forward(self):
    """Run the generator on the stored batch and cache its output as ``self.fake_B``."""
    generator = self.netG
    self.fake_B = generator(self.real_A, self.guide)
# load/define networks
# Build the generator from the option object `opt`. n_blocks and n_downsampling
# are not passed, so define_G's defaults (9 and 3) apply.
self.netG = networks.define_G(
    input_nc=opt.input_nc,
    guide_nc=opt.guide_nc,
    output_nc=opt.output_nc,
    ngf=opt.ngf,
    netG=opt.netG,
    n_layers=opt.n_layers,
    norm=opt.norm,
    init_type=opt.init_type,
    init_gain=opt.init_gain,
    gpu_ids=self.gpu_ids,
)
models/networks.py
def define_G(input_nc, guide_nc, output_nc, ngf, netG, n_layers=8, n_downsampling=3, n_blocks=9,
             norm='batch', init_type='normal', init_gain=0.02, gpu_ids=None):
    """Build and initialize a bFT generator.

    Args:
        input_nc: number of channels of the input image.
        guide_nc: number of channels of the guidance tensor.
        output_nc: number of channels of the generated image.
        ngf: base number of generator filters.
        netG: architecture name, either 'bFT_resnet' or 'bFT_unet'.
        n_layers: depth passed to bFT_Unet (ignored for 'bFT_resnet').
        n_downsampling: currently unused -- neither generator receives it
            (bFT_Resnet fixes its own downsampling count). Kept so existing
            callers keep working.
        n_blocks: number of residual blocks for bFT_Resnet (ignored for 'bFT_unet').
        norm: normalization name handed to get_norm_layer.
        init_type: weight-initialization scheme handed to init_net.
        init_gain: initialization gain handed to init_net.
        gpu_ids: device ids handed to init_net; None (the default) means [].

    Returns:
        Whatever init_net returns for the constructed network.

    Raises:
        NotImplementedError: if ``netG`` is not a recognized architecture name.
    """
    # Use a fresh list per call instead of a shared mutable default argument.
    if gpu_ids is None:
        gpu_ids = []
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == 'bFT_resnet':
        net = bFT_Resnet(input_nc, guide_nc, output_nc, ngf, norm_layer=norm_layer, n_blocks=n_blocks)
    elif netG == 'bFT_unet':
        net = bFT_Unet(input_nc, guide_nc, output_nc, n_layers, ngf, norm_layer=norm_layer)
    else:
        raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
    return init_net(net, init_type, init_gain, gpu_ids)
看一下bFT_Resnet
class bFT_Resnet(nn.Module):
    """ResNet-style generator with bi-directional feature transformation (bFT).

    Two parallel encoders process the input image and the guidance. At each of
    the first three encoder stages, each stream predicts per-channel, per-pixel
    FiLM parameters (alpha, beta). These are applied to the *other* stream via
    ``affine_transformation`` (defined elsewhere). At the fourth stage only the
    guidance modulates the input; the guidance stream is then discarded. The
    modulated input goes through ``n_blocks`` ResnetBlocks and a
    transposed-convolution decoder ending in Tanh.

    Args:
        input_nc: number of channels of the input image.
        guide_nc: number of channels of the guidance tensor.
        output_nc: number of channels of the generated image.
        ngf: base number of filters; stage k of the encoder has ngf * 2**(k-1) channels.
        n_blocks: number of ResnetBlocks between encoder and decoder.
        norm_layer: normalization constructor, used by the ResnetBlocks and the
            decoder (the encoder convolutions are not normalized).
        padding_type: padding mode forwarded to ResnetBlock.
        bottleneck_depth: hidden width of each 1x1-conv FiLM parameter predictor.
    """

    def __init__(self, input_nc, guide_nc, output_nc, ngf=64, n_blocks=9, norm_layer=nn.BatchNorm2d,
                 padding_type='reflect', bottleneck_depth=100):
        super(bFT_Resnet, self).__init__()
        # One in-place ReLU module, shared by every stage and every FiLM predictor.
        self.activation = nn.ReLU(True)
        # forward() is written for exactly three stride-2 stages (conv1..conv3 and
        # conv1_g..conv3_g), so this is a constant rather than a constructor argument.
        n_downsampling = 3

        ## input stream: reflection-padded 7x7 conv (keeps H x W), then three stride-2 convs
        padding_in = [nn.ReflectionPad2d(3), nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0)]
        self.padding_in = nn.Sequential(*padding_in)
        self.conv1 = nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=2, padding=1)

        ## guidance stream: same layout as the input stream, separate weights
        padding_g = [nn.ReflectionPad2d(3), nn.Conv2d(guide_nc, ngf, kernel_size=7, padding=0)]
        self.padding_g = nn.Sequential(*padding_g)
        self.conv1_g = nn.Conv2d(ngf, ngf * 2, kernel_size=3, stride=2, padding=1)
        self.conv2_g = nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, stride=2, padding=1)
        self.conv3_g = nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, stride=2, padding=1)

        # FiLM parameter predictors: one alpha/beta pair per stream and per stage.
        # Unprefixed predictors read input-stream features; "G_" predictors read
        # guidance-stream features. Names and registration order are unchanged so
        # parameter order and state_dict keys match existing checkpoints.
        # bottleneck1
        self.bottleneck_alpha_1 = self.bottleneck_layer(ngf, bottleneck_depth)
        self.G_bottleneck_alpha_1 = self.bottleneck_layer(ngf, bottleneck_depth)
        self.bottleneck_beta_1 = self.bottleneck_layer(ngf, bottleneck_depth)
        self.G_bottleneck_beta_1 = self.bottleneck_layer(ngf, bottleneck_depth)
        # bottleneck2
        self.bottleneck_alpha_2 = self.bottleneck_layer(ngf*2, bottleneck_depth)
        self.G_bottleneck_alpha_2 = self.bottleneck_layer(ngf*2, bottleneck_depth)
        self.bottleneck_beta_2 = self.bottleneck_layer(ngf*2, bottleneck_depth)
        self.G_bottleneck_beta_2 = self.bottleneck_layer(ngf*2, bottleneck_depth)
        # bottleneck3
        self.bottleneck_alpha_3 = self.bottleneck_layer(ngf*4, bottleneck_depth)
        self.G_bottleneck_alpha_3 = self.bottleneck_layer(ngf*4, bottleneck_depth)
        self.bottleneck_beta_3 = self.bottleneck_layer(ngf*4, bottleneck_depth)
        self.G_bottleneck_beta_3 = self.bottleneck_layer(ngf*4, bottleneck_depth)
        # bottleneck4
        # NOTE: bottleneck_alpha_4 / bottleneck_beta_4 are never used by forward()
        # (stage 4 only modulates the input by the guidance). They are kept for
        # checkpoint compatibility; DistributedDataParallel would need
        # find_unused_parameters=True because of them.
        self.bottleneck_alpha_4 = self.bottleneck_layer(ngf*8, bottleneck_depth)
        self.G_bottleneck_alpha_4 = self.bottleneck_layer(ngf*8, bottleneck_depth)
        self.bottleneck_beta_4 = self.bottleneck_layer(ngf*8, bottleneck_depth)
        self.G_bottleneck_beta_4 = self.bottleneck_layer(ngf*8, bottleneck_depth)

        # residual blocks at the lowest resolution (ngf * 8 channels)
        mult = 2**n_downsampling
        resnet = [ResnetBlock(ngf * mult, padding_type=padding_type, activation=self.activation, norm_layer=norm_layer)
                  for _ in range(n_blocks)]
        self.resnet = nn.Sequential(*resnet)

        # decoder: three stride-2 transposed convs (each exactly doubles H and W and
        # halves the channels), then a reflection-padded 7x7 conv and Tanh
        decoder = []
        for i in range(n_downsampling):
            mult = 2**(n_downsampling - i)
            decoder += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=2, padding=1, output_padding=1),
                        norm_layer(int(ngf * mult / 2)), self.activation]
        self.pre_decoder = nn.Sequential(*decoder)
        self.decoder = nn.Sequential(*[nn.ReflectionPad2d(3), nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0), nn.Tanh()])

    def bottleneck_layer(self, nc, bottleneck_depth):
        """Return a FiLM parameter predictor mapping ``nc`` channels to ``nc`` channels.

        The predictor is a 1x1 conv (nc -> bottleneck_depth), the shared ReLU,
        and a 1x1 conv (bottleneck_depth -> nc). It preserves the input's shape,
        so it yields one value per channel and pixel.
        """
        return nn.Sequential(nn.Conv2d(nc, bottleneck_depth, kernel_size=1), self.activation,
                             nn.Conv2d(bottleneck_depth, nc, kernel_size=1))

    def get_FiLM_param_(self, X, i, guide=False):
        """Predict FiLM parameters (alpha, beta) from feature map ``X`` at stage ``i``.

        Args:
            X: feature map whose channel count matches stage ``i``
               (ngf, 2*ngf, 4*ngf, 8*ngf for i = 1..4).
            i: stage index, 1 to 4.
            guide: if truthy, use the "G_" predictors (for guidance-stream
                features); otherwise use the unprefixed ones (for input-stream features).

        Returns:
            (alpha, beta), each with the same shape as ``X``.

        Raises:
            ValueError: if ``i`` is not 1, 2, 3 or 4. The old if/elif version
                failed here with an UnboundLocalError.
        """
        if i not in (1, 2, 3, 4):
            raise ValueError('FiLM stage index must be 1, 2, 3 or 4, got %r' % (i,))
        # Defensive copy: the predictors save their input for backward, so a later
        # in-place operation on X cannot corrupt those saved tensors.
        x = X.clone()
        prefix = 'G_' if guide else ''
        alpha_layer = getattr(self, '%sbottleneck_alpha_%d' % (prefix, i))
        beta_layer = getattr(self, '%sbottleneck_beta_%d' % (prefix, i))
        alpha = alpha_layer(x)
        beta = beta_layer(x)
        return alpha, beta

    def _bidirectional_ft(self, input, guidance, stage):
        """Run one bi-directional stage: cross-modulate both streams, then apply ReLU to each."""
        # Both parameter sets are computed before either stream is transformed,
        # so each stream is modulated using the other stream's untransformed features.
        g_alpha, g_beta = self.get_FiLM_param_(guidance, stage, guide=True)
        i_alpha, i_beta = self.get_FiLM_param_(input, stage)
        input = affine_transformation(input, g_alpha, g_beta)
        guidance = affine_transformation(guidance, i_alpha, i_beta)
        return self.activation(input), self.activation(guidance)

    def forward(self, input, guidance):
        """Generate an output image from ``input`` conditioned on ``guidance``.

        NOTE(review): FiLM parameters predicted from one stream are applied to
        the other, so the two tensors presumably need matching spatial size --
        confirm against affine_transformation. The output recovers the input's
        H x W when both are divisible by 8 (three stride-2 stages).
        """
        input = self.padding_in(input)
        guidance = self.padding_g(guidance)

        # Stages 1-3: bi-directional transformation, then a stride-2 conv on each stream.
        downsamplers = ((self.conv1, self.conv1_g), (self.conv2, self.conv2_g), (self.conv3, self.conv3_g))
        for stage, (conv, conv_g) in enumerate(downsamplers, start=1):
            input, guidance = self._bidirectional_ft(input, guidance, stage)
            input = conv(input)
            guidance = conv_g(guidance)

        # Stage 4: guidance -> input only; the guidance stream is discarded after this step.
        g_alpha4, g_beta4 = self.get_FiLM_param_(guidance, 4, guide=True)
        input = affine_transformation(input, g_alpha4, g_beta4)
        input = self.activation(input)

        input = self.resnet(input)
        input = self.pre_decoder(input)
        return self.decoder(input)