PSGAN Network Modifications
The PSGAN I built previously was only half finished: the attention module computed its weights from the source and reference feature maps alone, without using the relative distance between each pixel and the facial landmarks. This week I therefore revised both the network architecture and the data-processing code.
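Concretely, the attention weight that the modified AMM computes can be sketched as follows (this is my reading of the code below; the 0.01 factor down-weighting the visual features and the 200 temperature are the hyperparameters hard-coded in AMM, and in the code the same-region restriction is implemented by zeroing features outside each region before the dot product rather than masking after the softmax):

$$A_{x,y}=\operatorname{softmax}_y\big(200\cdot[\,0.01\,v_x,\;p_x\,]^{\top}[\,0.01\,v_y,\;p_y\,]\big)\cdot\mathbb{1}\big[\text{region}(x)=\text{region}(y)\big]$$

Here $v_x$ is the visual feature of pixel $x$ and $p_x$ is its 136-dimensional relative position to the 68 landmarks.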
The network architecture code is as follows:
import torch
import torch.nn as nn
import torch.nn.functional as F
from ops.spectral_norm import spectral_norm as SpectralNorm
class ResidualBlock(nn.Module):
    """Residual block with two 3x3 convolutions and instance normalization."""
    def __init__(self, dim_in, dim_out, net_mode=None):
        super(ResidualBlock, self).__init__()
        if net_mode == 'MDNet' or (net_mode is None):
            use_affine = True
        elif net_mode == 'MANet':
            # MANet blocks use non-affine instance norm
            use_affine = False
        else:
            raise ValueError("net_mode must be 'MDNet', 'MANet' or None")
        self.main = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine),
            nn.ReLU(inplace=True),
            nn.Conv2d(dim_out, dim_out, kernel_size=3, stride=1, padding=1, bias=False),
            nn.InstanceNorm2d(dim_out, affine=use_affine))

    def forward(self, x):
        return x + self.main(x)
class Discriminator(nn.Module):
    """Discriminator. PatchGAN."""
    def __init__(self, image_size=128, conv_dim=64, repeat_num=3, norm='SN'):
        super(Discriminator, self).__init__()
        layers = []
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1)))
        else:
            layers.append(nn.Conv2d(3, conv_dim, kernel_size=4, stride=2, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = conv_dim
        for i in range(1, repeat_num):
            if norm == 'SN':
                layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1)))
            else:
                layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1))
            layers.append(nn.LeakyReLU(0.01, inplace=True))
            curr_dim = curr_dim * 2
        # k_size = int(image_size / np.power(2, repeat_num))
        if norm == 'SN':
            layers.append(SpectralNorm(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1)))
        else:
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=1, padding=1))
        layers.append(nn.LeakyReLU(0.01, inplace=True))
        curr_dim = curr_dim * 2
        self.main = nn.Sequential(*layers)
        if norm == 'SN':
            self.conv1 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False))
        else:
            self.conv1 = nn.Conv2d(curr_dim, 1, kernel_size=4, stride=1, padding=1, bias=False)
        # conv1 keeps the spatial layout: a 256x256 input yields a 30x30 patch map
        # self.conv2 = SpectralNorm(nn.Conv2d(curr_dim, 1, kernel_size=k_size, bias=False))
        # conv2 would output a single number per image

    def forward(self, x):
        h = self.main(x)
        out_makeup = self.conv1(h)
        return out_makeup.squeeze()
class VGG(nn.Module):
    def __init__(self, pool='max'):
        super(VGG, self).__init__()
        # vgg modules
        self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv3_4 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv4_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.conv5_4 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        if pool == 'max':
            self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
        elif pool == 'avg':
            self.pool1 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool2 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool3 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool4 = nn.AvgPool2d(kernel_size=2, stride=2)
            self.pool5 = nn.AvgPool2d(kernel_size=2, stride=2)

    def forward(self, x, out_keys):
        out = {'r11': F.relu(self.conv1_1(x))}
        out['r12'] = F.relu(self.conv1_2(out['r11']))
        out['p1'] = self.pool1(out['r12'])
        out['r21'] = F.relu(self.conv2_1(out['p1']))
        out['r22'] = F.relu(self.conv2_2(out['r21']))
        out['p2'] = self.pool2(out['r22'])
        out['r31'] = F.relu(self.conv3_1(out['p2']))
        out['r32'] = F.relu(self.conv3_2(out['r31']))
        out['r33'] = F.relu(self.conv3_3(out['r32']))
        out['r34'] = F.relu(self.conv3_4(out['r33']))
        out['p3'] = self.pool3(out['r34'])
        out['r41'] = F.relu(self.conv4_1(out['p3']))
        out['r42'] = F.relu(self.conv4_2(out['r41']))
        out['r43'] = F.relu(self.conv4_3(out['r42']))
        out['r44'] = F.relu(self.conv4_4(out['r43']))
        out['p4'] = self.pool4(out['r44'])
        out['r51'] = F.relu(self.conv5_1(out['p4']))
        out['r52'] = F.relu(self.conv5_2(out['r51']))
        out['r53'] = F.relu(self.conv5_3(out['r52']))
        out['r54'] = F.relu(self.conv5_4(out['r53']))
        out['p5'] = self.pool5(out['r54'])
        return [out[key] for key in out_keys]
# Makeup Apply Network (MANet)
class Generator(nn.Module):
    """Generator. Encoder-Decoder Architecture."""
    def __init__(self, conv_dim=64):
        super(Generator, self).__init__()
        encoder_layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                          nn.InstanceNorm2d(conv_dim, affine=False), nn.ReLU(inplace=True)]
        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            encoder_layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            encoder_layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=False))
            encoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        # Bottleneck
        for i in range(3):
            encoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MANet'))
        decoder_layers = []
        for i in range(3):
            decoder_layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MANet'))
        # Up-Sampling
        for i in range(2):
            decoder_layers.append(
                nn.ConvTranspose2d(curr_dim, curr_dim // 2, kernel_size=4, stride=2, padding=1, bias=False))
            decoder_layers.append(nn.InstanceNorm2d(curr_dim // 2, affine=True))
            decoder_layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim // 2
        decoder_layers.append(nn.Conv2d(curr_dim, 3, kernel_size=7, stride=1, padding=3, bias=False))
        decoder_layers.append(nn.Tanh())
        self.encoder = nn.Sequential(*encoder_layers)
        self.decoder = nn.Sequential(*decoder_layers)
        self.MDNet = MDNet()
        self.AMM = AMM()

    def forward(self, source_image, mask_source, rel_pos_source, reference_image, mask_ref, rel_pos_ref):
        fm_source = self.encoder(source_image)
        fm_reference = self.MDNet(reference_image)
        morphed_fm = self.AMM(fm_source, fm_reference, mask_source, mask_ref, rel_pos_source, rel_pos_ref)
        result = self.decoder(morphed_fm)
        return result
class MDNet(nn.Module):
    """Makeup Distill Network: extracts makeup features from the reference image."""
    # MDNet is similar to the encoder of StarGAN
    def __init__(self, conv_dim=64):
        super(MDNet, self).__init__()
        layers = [nn.Conv2d(3, conv_dim, kernel_size=7, stride=1, padding=3, bias=False),
                  nn.InstanceNorm2d(conv_dim, affine=True), nn.ReLU(inplace=True)]
        # Down-Sampling
        curr_dim = conv_dim
        for i in range(2):
            layers.append(nn.Conv2d(curr_dim, curr_dim * 2, kernel_size=4, stride=2, padding=1, bias=False))
            layers.append(nn.InstanceNorm2d(curr_dim * 2, affine=True))
            layers.append(nn.ReLU(inplace=True))
            curr_dim = curr_dim * 2
        # Bottleneck
        for i in range(3):
            layers.append(ResidualBlock(dim_in=curr_dim, dim_out=curr_dim, net_mode='MDNet'))
        self.main = nn.Sequential(*layers)

    def forward(self, reference_image):
        fm_reference = self.main(reference_image)
        return fm_reference
# AMM, adapted from the official PSGAN code
class AMM(nn.Module):
    """Attentive Makeup Morphing module."""
    def __init__(self):
        super(AMM, self).__init__()
        self.visual_feature_weight = 0.01
        self.lambda_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.beta_matrix_conv = nn.Conv2d(in_channels=256, out_channels=1, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)

    @staticmethod
    def get_attention_map(mask_source, mask_ref, fm_source, fm_reference, rel_pos_source, rel_pos_ref):
        HW = 64 * 64
        batch_size = 3  # the three facial regions (lip, face, eyes), not the actual batch size
        # split the feature maps into the three facial parts using the masks
        channels = fm_reference.shape[1]
        mask_source_re = F.interpolate(mask_source, size=64).repeat(1, channels, 1, 1)  # (3, c, h, w)
        fm_source = fm_source.repeat(3, 1, 1, 1)  # (3, c, h, w), 3 stands for 3 parts
        # when computing attention, we only consider pixels belonging to the same facial region
        fm_source = fm_source * mask_source_re
        mask_ref_re = F.interpolate(mask_ref, size=64).repeat(1, channels, 1, 1)
        fm_reference = fm_reference.repeat(3, 1, 1, 1)
        fm_reference = fm_reference * mask_ref_re
        # concatenate the down-weighted visual features with the relative-position features
        theta_input = torch.cat((fm_source * 0.01, rel_pos_source), dim=1)
        phi_input = torch.cat((fm_reference * 0.01, rel_pos_ref), dim=1)
        theta_target = theta_input.view(batch_size, -1, HW)  # (N, C+136, H*W)
        theta_target = theta_target.permute(0, 2, 1)  # (N, H*W, C+136)
        phi_source = phi_input.view(batch_size, -1, HW)  # (N, C+136, H*W)
        weight = torch.bmm(theta_target, phi_source)  # (3, HW, HW)
        weight = weight.cpu()
        weight_ind = torch.LongTensor(weight.detach().numpy().nonzero())
        weight = weight.cuda()
        weight_ind = weight_ind.cuda()
        weight *= 200  # hyperparameter for the visual feature
        weight = F.softmax(weight, dim=-1)
        weight = weight[weight_ind[0], weight_ind[1], weight_ind[2]]
        # open question: why not merge the three parts into a single 1 x HW x HW weight here?
        return torch.sparse.FloatTensor(weight_ind, weight, torch.Size([3, HW, HW]))
    @staticmethod
    def atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix):
        # The paper says the idea of gamma and beta comes from style transfer, but this is not
        # general style transfer, so the masks are used to compute a style per facial region.
        batch_size, channels, width, height = old_gamma_matrix.size()
        # channels = gamma_ref.shape[1]
        mask_ref_re = F.interpolate(mask_ref, size=old_gamma_matrix.shape[2:]).repeat(1, channels, 1, 1)
        gamma_ref_re = old_gamma_matrix.repeat(3, 1, 1, 1)
        old_gamma_matrix = gamma_ref_re * mask_ref_re  # (3, 1, h, w)
        print('old_gamma_matrix shape1: ', old_gamma_matrix.shape)
        beta_ref_re = old_beta_matrix.repeat(3, 1, 1, 1)
        old_beta_matrix = beta_ref_re * mask_ref_re
        old_gamma_matrix = old_gamma_matrix.view(3, 1, -1)
        print('old_gamma_matrix shape2: ', old_gamma_matrix.shape)
        old_beta_matrix = old_beta_matrix.view(3, 1, -1)
        old_gamma_matrix = old_gamma_matrix.permute(0, 2, 1)
        old_beta_matrix = old_beta_matrix.permute(0, 2, 1)
        print('old_gamma_matrix shape3: ', old_gamma_matrix.shape)
        print('attention_map.to_dense() shape: ', attention_map.to_dense().shape)
        # morph gamma and beta from the reference layout to the source layout
        new_gamma_matrix = torch.bmm(attention_map.to_dense(), old_gamma_matrix)
        new_beta_matrix = torch.bmm(attention_map.to_dense(), old_beta_matrix)
        gamma = new_gamma_matrix.view(-1, 1, width, height)  # (3, 1, h, w)
        beta = new_beta_matrix.view(-1, 1, width, height)
        gamma = (gamma[0] + gamma[1] + gamma[2]).unsqueeze(0)  # combine the three parts
        beta = (beta[0] + beta[1] + beta[2]).unsqueeze(0)
        return gamma, beta
    def forward(self, fm_source, fm_reference, mask_source, mask_ref, rel_pos_source, rel_pos_ref):
        # batch_size, channels, width, height = fm_reference.size()
        old_gamma_matrix = self.lambda_matrix_conv(fm_reference)
        old_beta_matrix = self.beta_matrix_conv(fm_reference)
        attention_map = self.get_attention_map(mask_source, mask_ref, fm_source, fm_reference,
                                               rel_pos_source, rel_pos_ref)
        gamma, beta = self.atten_feature(mask_ref, attention_map, old_gamma_matrix, old_beta_matrix)
        # apply the morphed makeup parameters to the source feature map
        morphed_fm_source = fm_source * (1 + gamma) + beta
        return morphed_fm_source
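To sanity-check the tensor shapes, I find a sketch like the following useful (this is not part of the project code: it feeds random inputs shaped like the outputs of the preprocessing below, needs a GPU because get_attention_map calls .cuda() internally, and keeps the masks small so the dense attention map stays manageable):

import torch

# hypothetical smoke test; shapes follow the preprocessing code below:
#   image:   (1, 3, 256, 256), normalized to [-1, 1]
#   mask:    (3, 1, 256, 256), one channel per facial region (lip, face, eyes)
#   rel_pos: (3, 136, 64, 64), normalized landmark-relative positions
G = Generator().cuda()
src = torch.randn(1, 3, 256, 256).cuda()
ref = torch.randn(1, 3, 256, 256).cuda()
mask_src = torch.zeros(3, 1, 256, 256).cuda()
mask_ref = torch.zeros(3, 1, 256, 256).cuda()
mask_src[:, :, 96:160, 96:160] = 1  # a small fake "facial region"
mask_ref[:, :, 96:160, 96:160] = 1
rel_src = torch.randn(3, 136, 64, 64).cuda()
rel_ref = torch.randn(3, 136, 64, 64).cuda()
with torch.no_grad():
    out = G(src, mask_src, rel_src, ref, mask_ref, rel_ref)
print(out.shape)  # expected: torch.Size([1, 3, 256, 256])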
The data-loading code in makeup_utils also needs to be modified accordingly:
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms
import faceutils as futils
from ops.histogram_matching import *

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
def ToTensor(pic):
    # handle PIL Image
    if pic.mode == 'I':
        img = torch.from_numpy(np.array(pic, np.int32, copy=False))
    elif pic.mode == 'I;16':
        img = torch.from_numpy(np.array(pic, np.int16, copy=False))
    else:
        img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
    # PIL image modes: 1, L, P, I, F, RGB, YCbCr, RGBA, CMYK
    if pic.mode == 'YCbCr':
        nchannel = 3
    elif pic.mode == 'I;16':
        nchannel = 1
    else:
        nchannel = len(pic.mode)
    img = img.view(pic.size[1], pic.size[0], nchannel)
    # put it from HWC to CHW format
    # yikes, this transpose takes 80% of the loading time/CPU
    img = img.transpose(0, 1).transpose(0, 2).contiguous()
    if isinstance(img, torch.ByteTensor):
        return img.float()
    else:
        return img
def copy_area(tar, src, lms):
    # copy a rectangle around the given landmarks (plus a margin) from src into tar
    rect = [int(min(lms[:, 1])) - preprocess_image.eye_margin,
            int(min(lms[:, 0])) - preprocess_image.eye_margin,
            int(max(lms[:, 1])) + preprocess_image.eye_margin + 1,
            int(max(lms[:, 0])) + preprocess_image.eye_margin + 1]
    tar[:, :, rect[1]:rect[3], rect[0]:rect[2]] = \
        src[:, :, rect[1]:rect[3], rect[0]:rect[2]]
def to_var(x, requires_grad=True):
    # wrap x as a float Variable, propagating the requested requires_grad flag
    return Variable(x, requires_grad=requires_grad).float()
def preprocess_image(image: Image.Image):
    face = futils.dlib.detect(image)
    assert face, "no faces detected"
    # face[0] is the first detected face; the input image must contain exactly one face
    face = face[0]
    image, face = futils.dlib.crop(image, face)
    # detect landmarks
    lms = futils.dlib.landmarks(image, face) * 256 / image.width
    lms = lms.round()
    lms_eye_left = lms[42:48]
    lms_eye_right = lms[36:42]
    lms = lms.transpose((1, 0)).reshape(-1, 1, 1)  # transpose to (y-x)
    lms = np.tile(lms, (1, 256, 256))  # (136, h, w)
    # calculate the relative position of each pixel to the landmarks
    fix = np.zeros((256, 256, 68 * 2))
    for i in range(256):  # row (y) h
        for j in range(256):  # column (x) w
            fix[i, j, :68] = i
            fix[i, j, 68:] = j
    fix = fix.transpose((2, 0, 1))  # (136, h, w)
    diff = to_var(torch.Tensor(fix - lms).unsqueeze(0), requires_grad=False)
    # obtain the face parsing result
    image = image.resize((512, 512), Image.ANTIALIAS)
    mask = futils.mask.mask(image).resize((256, 256), Image.ANTIALIAS)
    mask = to_var(ToTensor(mask).unsqueeze(0), requires_grad=False)
    mask_lip = (mask == 7).float() + (mask == 9).float()
    mask_face = (mask == 1).float() + (mask == 6).float()
    # the eye masks have to be cut out around the eye landmarks
    mask_eyes = torch.zeros_like(mask)
    copy_area(mask_eyes, mask_face, lms_eye_left)
    copy_area(mask_eyes, mask_face, lms_eye_right)
    mask_eyes = to_var(mask_eyes, requires_grad=False)
    mask_list = [mask_lip, mask_face, mask_eyes]
    mask_aug = torch.cat(mask_list, 0)  # (3, 1, h, w)
    # up- or down-sample the input to the given size (or scale_factor)
    mask_re = F.interpolate(mask_aug, size=preprocess_image.diff_size).repeat(1, diff.shape[1], 1, 1)  # (3, 136, 64, 64)
    diff_re = F.interpolate(diff, size=preprocess_image.diff_size).repeat(3, 1, 1, 1)  # (3, 136, 64, 64)
    # as in the paper, attention is restricted to pixels in the same facial region
    diff_re = diff_re * mask_re  # (3, 136, 64, 64)
    # with dim=1 the norm has shape (3, 1, 64, 64): the length of each relative-position vector
    norm = torch.norm(diff_re, dim=1, keepdim=True).repeat(1, diff_re.shape[1], 1, 1)
    # torch.where merges two tensors according to a condition; here it avoids division by zero
    norm = torch.where(norm == 0, torch.tensor(1e10), norm)
    diff_re /= norm
    image = image.resize((256, 256), Image.ANTIALIAS)
    real = to_var(transform(image).unsqueeze(0))
    return [real, mask_aug, diff_re]
def preprocess_train_image(image: Image.Image, mask, diff_re):
    # masks and relative positions are precomputed for training images
    real = transform(image).unsqueeze(0)
    return [real, mask, diff_re]
# parameter of eye transfer
preprocess_image.eye_margin = 16
# down sample size
preprocess_image.diff_size = (64, 64)
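Putting the two files together, inference would look roughly like this (a sketch only: 'source.png' and 'reference.png' are placeholder paths, and the generator weights would of course have to be loaded from a trained checkpoint first):

import torch
from PIL import Image

source = Image.open('source.png').convert('RGB')
reference = Image.open('reference.png').convert('RGB')
real_s, mask_s, diff_s = preprocess_image(source)
real_r, mask_r, diff_r = preprocess_image(reference)

G = Generator().cuda().eval()  # load trained weights before use
with torch.no_grad():
    fake = G(real_s.cuda(), mask_s.cuda(), diff_s.cuda(),
             real_r.cuda(), mask_r.cuda(), diff_r.cuda())
# fake is (1, 3, 256, 256) in [-1, 1]; map it back to [0, 1] before saving
result = (fake.squeeze(0).cpu() + 1) / 2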