虚拟试穿
简介:本文梳理虚拟试穿算法框架结构,展示模特虚拟试穿上衣的效果,细说设计流程的详细步骤,提供相应的数据资源。
- 算法仓库:https://github.com/beauthy/DeepFashion_Try_On
github上不了,就访问:码云:虚拟试穿上衣测试:https://gitee.com/rpr/try-on_parse.git - 链接:https://pan.baidu.com/s/1nKUevnIMcGjaitVwIb7SRg
提取码:59wk - 测试用模型,鼓励大家根据网络训练自己的模型。
具体效果看下图or视频,想测试可以参见模型资源下载,之后Load和测试。
上述模型资源包括算法所设及的全部网络模型:latest_net_U.pth,latest_net_G1.pth,latest_net_G2.pth,latest_net_G.pth
测试效果:如图
测试效果:如视频
计算机视觉神经网络虚拟试穿测试
前言
本文将梳理算法实现过程原理。
提示:本文内容仅供学术研究与参考。
一、Try_On算法里面有什么?
0.环境; 1. 数据读取; 2. 数据模型:U-Net,G-Net; 3.损失函数; 4.调试常见的bug。二、梳理步骤
1.环境
代码如下(示例):
以上只等下次优化成requirements.txt再传上来。
2.读入数据
模型输入数据需要哪些呢?
测试数据集长什么样?
数据直观内容分析,我把模型需要的输入放一起,展示如下:
实际上,pose_关键点数据,和label_分割数据,是img_模特数据得到的(怎么生成关键点数据和人物分割数据的详细解读和代码,我再开一篇博客放上来);edge_数据就是待穿衣服color生成的。mask掩码数据是根据需要随机生成的。所以,完整的项目,的输入只需要模特和服装款式即可,也就是说可以实现给个人物和一件衣服就给实现换装。
再看,
看具体情况:通过photoshop的拾色器可以直观看到数据的值如下(label是灰度图):
背景的亮度L:0;面部的亮度L:9,左胳膊的亮度L:10,右胳膊的亮度L:8,上衣衣服位置的L:2。
把肢体图像分割出精确部分,用不同的亮度表示,到时候换衣服就有边界了。学姿势,纹理和褶皱等也有边界。
注意此处的L并不是该位置的像素值,只是亮度值。像素值可以用代码打印出来看。一下给出不同块儿的像素值。
# 具体划分区域Segmentation Label
0 -> Background
1 -> Hair
4 -> Upclothes
5 -> Left-shoe
6 -> Right-shoe
7 -> Noise
8 -> Pants
9 -> Left_leg
10 -> Right_leg
11 -> Left_arm
12 -> Face
13 -> Right_arm
注:名字带mask的三张掩码图,黑色区域亮度0,白色区域亮度为100,它们没有实际意义,可用于增加噪声,让模型稳定性好一些(我是这样理解的,因为训练的时候中间结果也有损失函数的backforword).
ok,以上就是输入数据,理清了没?
下面分析怎么使用的,以及后面模型是怎么组合的。
3. 输入配置
怎么读取数据集,生成模型输入需要的数据?项目写了一个配置options,专用于对数据集目录信息,训练测试信息和超参数进行配置的文件。
test.py
文件中:
opt = TrainOptions().parse()
ctrl+鼠标左键点击TestOptions
,找到opt对象的具体内容:
class TestOptions(BaseOptions):
def initialize(self):
BaseOptions.initialize(self)
......
ctrl+鼠标左键点击BaseOptions
:
class BaseOptions():
def __init__(self):
self.parser = argparse.ArgumentParser()
self.initialized = False
def initialize(self):
.....
TestOptions
类的initialize
函数系重写,但还是调用了BaseOptions.initialize(self)
的,所以BaseOptions.initialize
的数据也包含了的。根据训练和测试所需要的数据不同,控制生成数据集和其他超参数。
4. 数据集处理详细
生成数据迭代器,同其他pytorch自制数据集相差不大。重点文件时aligned_dataset.py
具体一起来看看:
test.py
文件中:
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
CreateDataLoader
是封装数据迭代器的类:
点进去看一眼
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name())
data_loader.initialize(opt)
return data_loader
具体内容在CustomDatasetDataLoader
:
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt)
self.dataset = CreateDataset(opt)
self.dataloader = torch.utils.data.DataLoader(
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size)
发现还封装了一层,数据是:self.dataset = CreateDataset(opt)
:
发现还有函数封装,
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name()))
dataset.initialize(opt)
return dataset
注意dataset.initialize(opt)
,封装数据过程中动不动都在初始化。
继续AlignedDataset
,ctrl+鼠标左键点击AlignedDataset
class AlignedDataset(BaseDataset):
def initialize(self, opt):
......
def __getitem__(self, index):
......
找到,def __getitem__(self, index):
,看到它是不是很眼熟了,就是pytorch生成batch数据可迭代数据。这个类继承于BaseDataset
,父类有transform等方法:
class BaseDataset(data.Dataset):
def __init__(self):
super(BaseDataset, self).__init__()
def name(self):
return 'BaseDataset'
def initialize(self, opt):
pass
def get_params(opt, size):
w, h = size
new_h = h
new_w = w
if opt.resize_or_crop == 'resize_and_crop':
new_h = new_w = opt.loadSize
elif opt.resize_or_crop == 'scale_width_and_crop':
new_w = opt.loadSize
new_h = opt.loadSize * h // w
x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
#flip = random.random() > 0.5
flip = 0
return {'crop_pos': (x, y), 'flip': flip}
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
transform_list = []
if 'resize' in opt.resize_or_crop:
osize = [opt.loadSize, opt.loadSize]
transform_list.append(transforms.Scale(osize, method))
elif 'scale_width' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))
osize = [256,192]
transform_list.append(transforms.Scale(osize, method))
if 'crop' in opt.resize_or_crop:
transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
if opt.resize_or_crop == 'none':
base = float(2 ** opt.n_downsample_global)
if opt.netG == 'local':
base *= (2 ** opt.n_local_enhancers)
transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))
if opt.isTrain and not opt.no_flip:
transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
transform_list += [transforms.ToTensor()]
if normalize:
transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
(0.5, 0.5, 0.5))]
return transforms.Compose(transform_list)
def normalize():
return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
def __make_power_2(img, base, method=Image.BICUBIC):
ow, oh = img.size
h = int(round(oh / base) * base)
w = int(round(ow / base) * base)
if (h == oh) and (w == ow):
return img
return img.resize((w, h), method)
def __scale_width(img, target_width, method=Image.BICUBIC):
ow, oh = img.size
if (ow == target_width):
return img
w = target_width
h = int(target_width * oh / ow)
return img.resize((w, h), method)
def __crop(img, pos, size):
ow, oh = img.size
x1, y1 = pos
tw = th = size
if (ow > tw or oh > th):
return img.crop((x1, y1, x1 + tw, y1 + th))
return img
def __flip(img, flip):
if flip:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img
既然找到数据了,就看看AlignedDataset
的数据生成方式吧。
首先是初始化,
def initialize(self, opt):
......
到这里还记得输入数据都有哪些吗?
待穿服装:color,轮廓edge;
模特:img,pose关键点,模特分割数据label;
掩码:两个黑背景掩码,一个白背景的掩码。
dir_C = '_color'
self.dir_C = os.path.join(opt.dataroot, opt.phase + dir_C)
self.C_paths = sorted(make_dataset(self.dir_C))
self.CR_paths = make_dataset(self.dir_C)
dir_E = '_edge'
self.dir_E = os.path.join(opt.dataroot, opt.phase + dir_E)
self.E_paths = sorted(make_dataset(self.dir_E))
self.ER_paths = make_dataset(self.dir_E)
dir_B = '_img'
self.dir_B = os.path.join(opt.dataroot, opt.phase + dir_B)
self.B_paths = sorted(make_dataset(self.dir_B))
self.BR_paths = sorted(make_dataset(self.dir_B))
# pose的关键点名称和img模特命名相差不大:pose_name = B_path.replace('.jpg', '_keypoints.json').replace('test_img', 'test_pose')
dir_A = '_label'
self.dir_A = os.path.join(opt.dataroot, opt.phase + dir_A)
self.A_paths = sorted(make_dataset(self.dir_A))
self.AR_paths = make_dataset(self.dir_A)
发现初始化就是把大家的地址放到明面上了,有的还排好了序,def __getitem__(self, index):
函数取数据就方便多了。
下面两个函数,就是具体生成路径字典或列表的:
def make_dataset(dir):
images = []
assert os.path.isdir(dir), '%s is not a valid directory' % dir
f = dir.split('/')[-1].split('_')[-1]
print(dir, f)
dirs = os.listdir(dir)
for img in dirs:
path = os.path.join(dir, img)
# print(path)
images.append(path)
return images
def build_index(self, dirs):
for k, dir in enumerate(dirs):
name = dir.split('/')[-1]
name = name.split('-')[0]
# print(name)
for k, d in enumerate(dirs[max(k - 20, 0):k + 20]):
if name in d:
if name not in self.diction.keys():
self.diction[name] = []
self.diction[name].append(d)
else:
self.diction[name].append(d)
到这里了,就最后看一眼,生成的数据长什么样吧!
def __getitem__(self, index):
......
if self.opt.isTrain:
input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor,
'path': A_path, 'path_ref': AR_path,
'edge': E_tensor, 'color': C_tensor, 'mask': M_tensor, 'colormask': MC_tensor,
'pose': P_tensor, 'name': name
}
else:
input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
return input_dict
对,就是他input_dict
,返回的字典。
注意:原作者共享的代码中,有这些后缀的文件夹,如图,我并没有相应后缀名文件,就配置训练模式,来测试数据了。
5.模型结构
继续看test.py文件,到模型板块
model = create_model(opt)
训练和测试使用的输入数据有区别的,训练的时候,除了将带穿衣服和模特的数据输入以外,还需要将穿好的结果输入,在最后面模型输出相比较,得出损失函数。
def create_model(opt):
if opt.model == 'pix2pixHD':
from .pix2pixHD_model import Pix2PixHDModel, InferenceModel
if opt.isTrain:
model = Pix2PixHDModel()
else:
model = InferenceModel()
model.initialize(opt)
if opt.verbose:
print("model [%s] was created" % (model.name()))
if opt.isTrain and len(opt.gpu_ids):
model = torch.nn.DataParallel(model, device_ids=opt.gpu_ids)
return model
测试model = InferenceModel()
类实例也是继承的Pix2PixHDModel
,重写了前向传播函数forward.
class InferenceModel(Pix2PixHDModel):
def forward(self, inp):
label = inp
return self.inference(label)
我们直接去看Pix2PixHDModel
:
pix2pixHD可以实现高分辨率图像生成和图片的语义编辑。
对于一个生成对抗网络(GAN),学习的关键就是理解生成器、判别器和损失函数这三部分。
pix2pixHD的生成器和判别器都是多尺度的,损失函数由GAN loss、Feature matching loss和Content loss组成。
class Pix2PixHDModel(BaseModel):
def name(self):
return 'Pix2PixHDModel'
def initialize(self, opt):
BaseModel.initialize(self, opt)
......
with torch.no_grad():
self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
self.G1 = networks.define_Refine(37, 14, self.gpu_ids).eval()
self.G2 = networks.define_Refine(19 + 18, 1, self.gpu_ids).eval()
self.G = networks.define_Refine(24, 3, self.gpu_ids).eval()
......
def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):
......
代码贴太多了,容易看不清重点。我就只贴些关键的,帮助理清楚整个网络的脉络。
Pix2PixHDModel
继承了BaseModel
,BaseModel
类有初始化函数,save_network函数,load_network函数。就没什么可以看的了。
from . import networks
networks.py文件中写了多种网络的具体结构,用类封装:
pix2pixHD模型类,初始化的时候,会对需要用到的网络进行初始化:
然后,在forward函数里面,是组合网络使用的方法和顺序,以及哪些地方需要计算损失来约束网络。
def forward(self,label,pre_clothes_mask,img_fore,clothes_mask,clothes,all_clothes_label,real_image,pose,mask):
# Encode Inputs
#ipdb.set_trace()
input_label,masked_label,all_clothes_label= self.encode_input(label,clothes_mask,all_clothes_label)
#ipdb.set_trace()
arm1_mask=torch.FloatTensor((label.cpu().numpy()==11).astype(np.float)).cuda()
arm2_mask=torch.FloatTensor((label.cpu().numpy()==13).astype(np.float)).cuda()
pre_clothes_mask=torch.FloatTensor((pre_clothes_mask.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
clothes=clothes*pre_clothes_mask
......
forward
函数的输入数据是:label, pre_clothes_mask, img_fore, clothes_mask, clothes, all_clothes_label, real_image, pose, mask
.
我们输入的数据:input_dict = {'label': A_tensor, 'label_ref': AR_tensor, 'image': B_tensor, 'image_ref': BR_tensor, 'path': A_path, 'path_ref': AR_path}
做了一点预处理的(加入高斯噪声,wash the label):
data_loader = CreateDataLoader(opt)
dataset = data_loader.load_data()
......
for epoch in range(start_epoch, opt.niter + opt.niter_decay + 1):
epoch_start_time = time.time()
if epoch != start_epoch:
epoch_iter = epoch_iter % dataset_size
for i, data in enumerate(dataset, start=epoch_iter):
mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))
mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))
img_fore = data['image'] * mask_fore
img_fore_wc = img_fore * mask_fore
all_clothes_label = changearm(data['label'])
############## 模型向前传播 ######################
losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(
Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),
Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()),
Variable(data['image'].cuda()),
Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))
data
里就是一次迭代获取dataset
里的一个batch
的数据。每个data
都是input_dict
样子,是字典。
看mask_clothes = torch.FloatTensor((data['label'].cpu().numpy() == 4).astype(np.int))
,将data
里label
对应的数据取出来处理(得到mask_clothes为区域分割数据label同尺寸的掩码图像,里面的值,原图等于4的是1,其他全0,就是把衣服区域取出来);label
是什么还记得吗?看下图:
看:mask_fore = torch.FloatTensor((data['label'].cpu().numpy() > 0).astype(np.int))
(得到mask_fore 为区域分割数据label同尺寸的掩码图像,里面的值,原图大于1的是1,其他全0)
看img_fore = data['image'] * mask_fore
他们相乘,就是在抠图,去掉背景得到img_fore
。
all_clothes_label = changearm(data['label'])
:调用changearm函数(变胳膊区域):
def changearm(old_label):
label = old_label
arm1 = torch.FloatTensor((data['label'].cpu().numpy() == 11).astype(np.int))
arm2 = torch.FloatTensor((data['label'].cpu().numpy() == 13).astype(np.int))
noise = torch.FloatTensor((data['label'].cpu().numpy() == 7).astype(np.int))
label = label * (1 - arm1) + arm1 * 4
label = label * (1 - arm2) + arm2 * 4
label = label * (1 - noise) + noise * 4
return label
changearm函数将模特区域分割数据label的左右胳膊取出来( == 11, == 13),把整幅噪声也找出来,将他们的值变成4.就像下图,把手也和衣服化为一个区域了。
############## Forward Pass ######################
losses, fake_image, real_image, input_label, L1_loss, style_loss, clothes_mask, CE_loss, rgb, alpha = model(
Variable(data['label'].cuda()), Variable(data['edge'].cuda()), Variable(img_fore.cuda()),
Variable(mask_clothes.cuda()), Variable(data['color'].cuda()), Variable(all_clothes_label.cuda()),
Variable(data['image'].cuda()),
Variable(data['pose'].cuda()), Variable(data['image'].cuda()), Variable(mask_fore.cuda()))
所以,输入的量有,label
–模特区域分割数据;edge
–衣服轮廓;img_fore
–去背景模特图像;mask_clothes
–模特正穿着的衣服的掩码区域,color
–衣服;all_clothes_label
–模特区域分割label将胳膊手融入衣服的区域分割图像;image
–模特;pose
–模特关键点;mask_fore
–模特区域;
对比看一看,模型中形参叫什么:
def forward(self,label,pre_clothes_mask,img_fore,
clothes_mask,clothes,all_clothes_label,
real_image,pose,grid,mask):
那就进入pix2pixHD模型的forward函数:
# Encode Inputs
input_label, masked_label, all_clothes_label = self.encode_input(label, clothes_mask, all_clothes_label)
对输入数据编码处理:
def encode_input(self, label_map, clothes_mask, all_clothes_label):
size = label_map.size()
oneHot_size = (size[0], 14, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
masked_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
masked_label = masked_label.scatter_(1, (label_map * (1 - clothes_mask)).data.long().cuda(), 1.0)
c_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
c_label = c_label.scatter_(1, all_clothes_label.data.long().cuda(), 1.0)
input_label = Variable(input_label)
return input_label, masked_label, c_label
手动实现one_hot 时,关于scatter_()函数: scatter_()函数有三个参数 scatter_(dim, index, src)
- dim指的是在哪个维度进行索引
- index指的是:用来进行索引的tensor
- src指scatter的源元素,可以是一个标量也可以是一个张量。
一句话解释上面的scatter:
input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
既input_label.scatter_(dim, index, src)将src中数据根据index中的索引按照dim的方向填进input_label中。
继续net_G1(conditional GAN):
G1_in = torch.cat([pre_clothes_mask, clothes, all_clothes_label, pose, self.gen_noise(shape)], dim=1)
arm_label = self.G1.refine(G1_in)
arm_label = self.sigmoid(arm_label)
CE_loss = self.cross_entropy2d(arm_label, (label * (1 - clothes_mask)).transpose(0, 1)[0].long()) * 10
直观了吧,看网络G1只有一个输出arm_label
。在训练中,就是模特换好新衣服后的分割图,与网络输出做损失反向传输。(模特原来是长袖,后面要穿短袖,自然胳膊是重点。脖子似乎没有人关心高领和低领的问题,待改进)。测试时直接生成要的穿color款衣服的模特分割图。
armlabel_map = generate_discrete_label(arm_label.detach(), 14, False)
dis_label = generate_discrete_label(arm_label.detach(), 14)
生成离散标签函数的输入是G1网络的输出结果arm_label
:
def generate_discrete_label(inputs, label_nc, onehot=True, encode=True):
pred_batch = []
size = inputs.size()
for input in inputs:
input = input.view(1, label_nc, size[2], size[3])
pred = np.squeeze(input.data.max(1)[1].cpu().numpy(), axis=0)
pred_batch.append(pred)
pred_batch = np.array(pred_batch)
pred_batch = torch.from_numpy(pred_batch)
label_map = []
for p in pred_batch:
p = p.view(1, 256, 192)
label_map.append(p)
label_map = torch.stack(label_map, 0)
if not onehot:
return label_map.float().cuda()
size = label_map.size()
oneHot_size = (size[0], label_nc, size[2], size[3])
input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
input_label = input_label.scatter_(1, label_map.data.long().cuda(), 1.0)
return input_label
上面要是看不明白出入输出变化,可以把结果或者结果的shape打印出来,对比着看。
继续net_G2:
G2_in = torch.cat([pre_clothes_mask, clothes, dis_label, pose, self.gen_noise(shape)], 1)
fake_cl = self.G2.refine(G2_in)
fake_cl = self.sigmoid(fake_cl)
CE_loss += self.BCE(fake_cl, clothes_mask) * 10
G2的输入,是G1的输出+Pose+color+edge+noise组合输入,输出为模特穿上新衣后衣服的轮廓。训练的时候,是模特穿上新的衣服数据的衣服轮廓与G2的输出做损失,反向传播的。测试时,G2输出模特换上新衣的轮廓数据,此时,还没有图案和纹理的变化。
损失函数BCE参考:https://blog.csdn.net/qq_22210253/article/details/85222093
继续:
fake_cl_dis = torch.FloatTensor((fake_cl.detach().cpu().numpy() > 0.5).astype(np.float)).cuda()
fake_cl_dis = morpho(fake_cl_dis, 1, True)
def morpho(mask, iter, bigger=True):
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
new = []
for i in range(len(mask)):
tem = mask[i].cpu().detach().numpy().squeeze().reshape(256, 192, 1) * 255
tem = tem.astype(np.uint8)
if bigger:
tem = cv2.dilate(tem, kernel, iterations=iter)
else:
tem = cv2.erode(tem, kernel, iterations=iter)
tem = tem.astype(np.float64)
tem = tem.reshape(1, 256, 192)
new.append(tem.astype(np.float64) / 255.0)
new = np.stack(new)
new = torch.FloatTensor(new).cuda()
return new
detach(): 神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整;或者值训练部分分支网络,并不让其梯度对主网络的梯度造成影响,torch.tensor.detach()和torch.tensor.detach_()函数来切断一些分支的反向传播。
cv2.getStructuringElement( ) 返回指定形状和尺寸的结构元素。
函数的第一个参数表示内核的形状,有三种形状可以选择。
矩形:MORPH_RECT;
交叉形:MORPH_CROSS;
椭圆形:MORPH_ELLIPSE;
第二和第三个参数分别是内核的尺寸以及锚点的位置。一般在调用erode以及dilate函数之前,先定义一个Mat类型的变量来获得getStructuringElement函数的返回值: 对于锚点的位置,有默认值Point(-1,-1),表示锚点位于中心点。element形状唯一依赖锚点位置,其他情况下,锚点只是影响了形态学运算结果的偏移。
cv2.erode()腐蚀:将前景物体变小,理解成将图像断开裂缝变大(在图片上画上黑色印记,印记越来越大)
dst = cv.erode(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])
cv2.dilate()膨胀:将前景物体变大,理解成将图像断开裂缝变小(在图片上画上黑色印记,印记越来越小)
dst = cv2.dilate(src, kernel[, dst[, anchor[, iterations[, borderType[, borderValue]]]]])
numpy.stack(arrays, axis=0)
沿着新轴连接数组的序列。
axis参数指定新轴在结果尺寸中的索引。例如,如果axis=0,它将是第一个维度,如果axis=-1,它将是最后一个维度。
参数: 数组:array_like的序列每个数组必须具有相同的形状。axis:int,可选输入数组沿其堆叠的结果数组中的轴。
返回: 堆叠:ndarray堆叠数组比输入数组多一个维。
new_arm1_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 11).astype(np.float)).cuda()
new_arm2_mask = torch.FloatTensor((armlabel_map.cpu().numpy() == 13).astype(np.float)).cuda()
fake_cl_dis = fake_cl_dis * (1 - new_arm1_mask) * (1 - new_arm2_mask)
fake_cl_dis *= mask_fore
arm1_occ = clothes_mask * new_arm1_mask
arm2_occ = clothes_mask * new_arm2_mask
bigger_arm1_occ = morpho(arm1_occ, 10)
bigger_arm2_occ = morpho(arm2_occ, 10)
arm1_full = arm1_occ + (1 - clothes_mask) * arm1_mask
arm2_full = arm2_occ + (1 - clothes_mask) * arm2_mask
armlabel_map *= (1 - new_arm1_mask)
armlabel_map *= (1 - new_arm2_mask)
armlabel_map = armlabel_map * (1 - arm1_full) + arm1_full * 11
armlabel_map = armlabel_map * (1 - arm2_full) + arm2_full * 13
armlabel_map *= (1 - fake_cl_dis)
dis_label = encode(armlabel_map, armlabel_map.shape)
fake_c, warped, warped_mask, warped_grid = self.Unet(clothes, fake_cl_dis, pre_clothes_mask, grid)
mask = fake_c[:, 3, :, :]
mask = self.sigmoid(mask) * fake_cl_dis
fake_c = self.tanh(fake_c[:, 0:3, :, :])
fake_c = fake_c * (1 - mask) + mask * warped
skin_color = self.ger_average_color((arm1_mask + arm2_mask - arm2_mask * arm1_mask),
(arm1_mask + arm2_mask - arm2_mask * arm1_mask) * real_image)
occlude = (1 - bigger_arm1_occ * (arm2_mask + arm1_mask + clothes_mask)) * (
1 - bigger_arm2_occ * (arm2_mask + arm1_mask + clothes_mask))
img_hole_hand = img_fore * (1 - clothes_mask) * occlude * (1 - fake_cl_dis)
self.Unet = networks.define_UnetMask(4, self.gpu_ids).eval()
前面的G1,G2的网络我没有展开。后面会专门分析网络里面的组成,和输入输出等细节。在这里,我们溯源一下这个Unet:
def define_UnetMask(input_nc, gpu_ids=[]):
netG = UnetMask(input_nc, output_nc=4)
netG.cuda(gpu_ids[0])
netG.apply(weights_init)
return netG
Unet来源于UnetMask:
class UnetMask(nn.Module):
def __init__(self, input_nc, output_nc=3):
super(UnetMask, self).__init__()
self.stn = STNNet()
nl = nn.InstanceNorm2d
self.conv1 = nn.Sequential(*[nn.Conv2d(input_nc, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU(),
nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), nl(64), nn.ReLU()])
self.pool1 = nn.MaxPool2d(kernel_size=(2, 2))
......
def forward(self, input, refer, mask, grid):
input, warped_mask, rx, ry, cx, cy, grid = self.stn(input, torch.cat([mask, refer, input], 1), mask, grid)
# print(input.shape)
conv1 = self.conv1(torch.cat([refer.detach(), input.detach()], 1))
......
conv9 = self.conv9(torch.cat([conv1, up9], 1))
return conv9, input, warped_mask, grid
UnetMask有一个特殊的网络层STNNet:
class STNNet(nn.Module):
def __init__(self):
super(STNNet, self).__init__()
range = 0.9
r1 = range
r2 = range
grid_size_h = 5
grid_size_w = 5
assert r1 < 1 and r2 < 1 # if >= 1, arctanh will cause error in BoundedGridLocNet
target_control_points = torch.Tensor(list(itertools.product(
np.arange(-r1, r1 + 0.00001, 2.0 * r1 / (grid_size_h - 1)),
np.arange(-r2, r2 + 0.00001, 2.0 * r2 / (grid_size_w - 1)),
)))
Y, X = target_control_points.split(1, dim=1)
target_control_points = torch.cat([X, Y], dim=1)
self.target_control_points = target_control_points
# self.get_row(target_control_points,5)
GridLocNet = {
'unbounded_stn': UnBoundedGridLocNet,
'bounded_stn': BoundedGridLocNet,
}['bounded_stn']
self.loc_net = GridLocNet(grid_size_h, grid_size_w, target_control_points)
self.tps = TPSGridGen(256, 192, target_control_points)
def get_row(self, coor, num):
for j in range(num):
sum = 0
buffer = 0
flag = False
max = -1
for i in range(num - 1):
differ = (coor[j * num + i + 1, :] - coor[j * num + i, :]) ** 2
if not flag:
second_dif = 0
flag = True
else:
second_dif = torch.abs(differ - buffer)
buffer = differ
sum += second_dif
print(sum / num)
def get_col(self, coor, num):
for i in range(num):
sum = 0
buffer = 0
flag = False
max = -1
for j in range(num - 1):
differ = (coor[(j + 1) * num + i, :] - coor[j * num + i, :]) ** 2
if not flag:
second_dif = 0
flag = True
else:
second_dif = torch.abs(differ - buffer)
buffer = differ
sum += second_dif
print(sum)
def forward(self, x, reference, mask, grid_pic):
batch_size = x.size(0)
source_control_points, rx, ry, cx, cy = self.loc_net(reference)
source_control_points = (source_control_points)
# print('control points',source_control_points.shape)
source_coordinate = self.tps(source_control_points)
grid = source_coordinate.view(batch_size, 256, 192, 2)
# print('grid size',grid.shape)
transformed_x = grid_sample(x, grid, canvas=0)
warped_mask = grid_sample(mask, grid, canvas=0)
warped_gpic = grid_sample(grid_pic, grid, canvas=0)
return transformed_x, warped_mask, rx, ry, cx, cy, warped_gpic
U_net不仅含有简单的神经网络层,还有STN网络层(spatial transform network,空间变换网络)。
前面还完成了step3的过程:
以上,G1,G2,Unet,step3完成后,才是G3网络。G3的输入:
G_in = torch.cat([img_hole_hand, dis_label, fake_c, skin_color, self.gen_noise(shape)], 1)
fake_image = self.G.refine(G_in.detach())
fake_image = self.tanh(fake_image)
返回所有输出结果:
return [self.loss_filter(loss_G_GAN, 0, loss_G_VGG, loss_D_real, loss_D_fake), fake_image,
clothes, arm_label, L1_loss, style_loss, fake_cl, CE_loss, real_image, warped_grid]
到这里,算是解释完了。
总结
总结,流程图最左侧三个灰蓝色模型,为基本输入color(服装),和img(模特)的预处理。本项目中,已经提供了它们三个的输出(未提供相关模型,不是重点),作为输入。
所以整个网络的关键输入就是:edge,color;pose,img,label.
整个网络就是G1+G2+Unet+G3构成;中间数据输入做了些些掩码和校正外,没有其他结构了。
小伙伴们,是不是弄清楚?
本文主要是自己巩固一下学习内容,有朋友询问有没有简化输入的方法,能不能不提供解析数据(区域分割)和关键点之类的,只输入模特和服装就可以看到试穿效果的呢,答案:当然有。免解析虚拟试穿参见此博客,虽然测试效果还不好,但简化输入也有了解决方案不是。
谨以此文与大家共勉!如果你觉得对你有用,请给个点赞。
参考文献
感谢原作者:
[1] Yang, Han and Zhang, Ruimao and Guo, Xiaobao and Liu, Wei and Zuo, Wangmeng and Luo, Ping.Towards Photo-Realistic Virtual Try-On by Adaptively Generating-Preserving Image Content,IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR),June,2020.