Down to the Last Detail: Virtual Try-on with Fine-grained Details, ACM MM 2020
Paper, Code/Model, ArXiv
深入到最后一个细节:能雕刻细节的虚拟试穿
理解得不对的地方欢迎评论区指正~
demo.py
foward函数(下)
#使用了torch.no_grad()上下文管理器,它会在其范围内禁用梯度计算,以减少内存占用并提高代码执行速度。
with torch.no_grad():
for i, result in enumerate(val_dataloader):#循环遍历val_dataloader中的数据。enumerate()函数可以同时返回数据的下标i和数据内容result
#服装空间对齐模块(Clothing Spatial Alignment Module):用于得到目标姿态下的变形后的服装图像。
'warped cloth'
warped_cloth = warped_image(gmm, result)#执行warped_image函数,将数据传入gmm生成器模型中,并输出warped_cloth,表示变形后的衣服图像。
if opt.warp_cloth:#如果指定了opt.warp_cloth参数为True,则将warped_cloth保存到文件中
warped_cloth_name = result['warped_cloth_name']
warped_cloth_path = os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])
if not os.path.exists(os.path.split(warped_cloth_path)[0]):#判断 warped_cloth_path 路径所在的目录是否存在
os.makedirs(os.path.split(warped_cloth_path)[0])#不存在则使用 os.makedirs() 创建该目录
utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
print('processing_%d'%i)
continue
#从result中提取出了多个tensor数据,将其赋值给对应的变量名。
source_parse = result['source_parse'].float().cuda()#输入的人物姿态解析图
target_pose_embedding = result['target_pose_embedding'].float().cuda()#目标姿态的嵌入向量
source_image = result['source_image'].float().cuda()#输入的人物图像
cloth_parse = result['cloth_parse'].cuda()#衣服的姿态解析图
cloth_image = result['cloth_image'].cuda()#衣服的图像
target_pose_img = result['target_pose_img'].float().cuda()#目标姿态的图像
cloth_parse = result['cloth_parse'].float().cuda()#转换成float类型
source_parse_vis = result['source_parse_vis'].float().cuda()#可视化的人物姿态解析图
#将衣服的信息添加到输入张量中
"filter add cloth infomation"
real_s = source_parse#赋值
index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]#将real_s张量中索引为5、6和7的通道(通道数从0开始)从通道维度上去掉,得到real_s_张量
real_s_ = torch.index_select(real_s, 1, torch.tensor(index).cuda())#index_select函数从real_s张量的第1维度(即通道维度)中选择指定的索引,返回一个新的张量real_s_
input_parse = torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda()#将real_s_、target_pose_embedding和cloth_parse三个张量沿着通道维度(dim=1)拼接起来
#解析转换网络(Parsing Transformation Network):用于得到目标姿态下的目标语义图(人体分割图)。
'P'
generate_parse = generator_parsing(input_parse) # tanh 使用生成器模型generator_parsing对input_parse进行前向计算,生成一个张量generate_parse
generate_parse = F.softmax(generate_parse, dim=1)#softmax操作将每个元素变为正数,再将它们加起来变为1,得到了一个概率分布。
generate_parse_argmax = torch.argmax(generate_parse, dim=1, keepdim=True).float()#返回每一行中最大值的索引
res = []
for index in range(20):#for 循环将 generate_parse_argmax 的每一行转换成一个大小为 (batch_size, 20) 的 one-hot 向量。
res.append(generate_parse_argmax == index)#one-hot向量是指在一个固定长度的向量中,只有一个元素是1,其他元素都是0。这个唯一的1的位置表示一个特定的类别或状态,方便地处理分类问题。
generate_parse_argmax = torch.cat(res, dim=1).float()#将列表 res 中的所有 one-hot 向量按照第二个维度拼接起来,得到一个大小为 (batch_size, 20*20) 的 one-hot tensor,也就是 generate_parse_argmax。
#详细外观生成网络(Detailed Appearance Generation Network):CP-VTON模型中的一部分,用于生成初步的虚拟试衣结果。
"A"
image_without_cloth = create_part(source_image, source_parse, 'image_without_cloth', False)#生成源图像中不包含衣服的部分
input_app = torch.cat((image_without_cloth, warped_cloth, generate_parse), 1).cuda()#将 image_without_cloth、warped_cloth 和 generate_parse 拼接起来作为输入
generate_img = generator_app_cpvton(input_app)#送入generator_app_cpvton生成generate_img ,一个大小为 (batch_size, 6, H, W) 的张量,其中前3个通道是前景图像,后3个通道是遮罩图像。
p_rendered, m_composite = torch.split(generate_img, 3, 1) #通过 torch.split 将其拆分成前景图像 p_rendered 和遮罩图像 m_composite。
p_rendered = F.tanh(p_rendered)#p_rendered 通过 F.tanh 转换到 [-1, 1] 范围内
m_composite = F.sigmoid(m_composite)#m_composite 通过 F.sigmoid 转换到 [0, 1] 范围内
p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)#最终的虚拟试衣结果 p_tryon 通过加权平均的方式得到,其中 warped_cloth 的权重为遮罩图像,p_rendered 的权重为遮罩图像的补集。
refine_img = p_tryon
#面部细化网络(Face refinement Network):用于为面部区域增加更多的细节和真实感。
"F"
generate_face = create_part(refine_img, generate_parse_argmax, 'face', False)#从分割结果(generate_parse_argmax)中提取出人脸部分(generate_face)。
generate_img_without_face = refine_img - generate_face#与原始图像(refine_img)中的非人脸部分合成得到不包含人脸的图像(generate_img_without_face)。
source_face = create_part(source_image, source_parse, 'face', False)#从源图像(source_image)中提取出人脸部分(source_face)。
input_face = torch.cat((source_face, generate_face), 1)#将生成的人脸(generate_face)和源图像的人脸(source_face)连接起来
fake_face = generator_face(input_face)#送入生成器网络(generator_face)生成伪造的人脸(fake_face)。
fake_face = create_part(fake_face, generate_parse_argmax, 'face', False)
refine_img = generate_img_without_face + fake_face#将伪造的人脸(fake_face)插入到不包含人脸的图像(generate_img_without_face)中,得到最终的输出(refine_img)
#用于保存模型生成的结果图像
"generate parse vis"
if opt.save_time:
generate_parse_vis = source_parse_vis
else:
generate_parse_vis = torch.argmax(generate_parse, dim=1, keepdim=True).permute(0,2,3,1).contiguous()
generate_parse_vis = pose_utils.decode_labels(generate_parse_vis)#进行解码生成可视化的解析图像
"save results"
images = [source_image, cloth_image, target_pose_img, warped_cloth, source_parse_vis, generate_parse_vis, p_tryon, refine_img]#所需图像存储在 images 列表中
pose_utils.save_img(images, os.path.join(refine_path, '%d.jpg')%(i))#将列表中的图像保存在指定的路径下
torch.cuda.empty_cache()
main函数
if __name__ == "__main__":
#定义四个模型的预训练权重的路径
resume_gmm = "pretrained_checkpoint/step_009000.pth"
resume_G_parse = 'pretrained_checkpoint/parsing.tar'
resume_G_app_cpvton = 'pretrained_checkpoint/app.tar'
resume_G_face = 'pretrained_checkpoint/face.tar'
paths = [resume_gmm, resume_G_parse, resume_G_app_cpvton, resume_G_face]
#创建Config类的实例opt来接受传入参数
opt = Config().parse()
if not os.path.exists(opt.forward_save_path):
os.makedirs(opt.forward_save_path)
forward(opt, paths, 4, opt.forward_save_path)