虚拟试穿代码理解Down to the Last Detail: Virtual Try-on with Fine-grained Details(demo.py下)


Down to the Last Detail: Virtual Try-on with Fine-grained Details, ACM MM 2020
Paper, Code/Model, ArXiv
深入到最后一个细节:能雕刻细节的虚拟试穿
理解得不对的地方欢迎评论区指正~

demo.py

foward函数(下)

#使用了torch.no_grad()上下文管理器,它会在其范围内禁用梯度计算,以减少内存占用并提高代码执行速度。
    with torch.no_grad():
        for i, result in enumerate(val_dataloader):#循环遍历val_dataloader中的数据。enumerate()函数可以同时返回数据的下标i和数据内容result
            
            #服装空间对齐模块(Clothing Spatial Alignment Module):用于得到目标姿态下的变形后的服装图像。
            'warped cloth'
            warped_cloth = warped_image(gmm, result)#执行warped_image函数,将数据传入gmm生成器模型中,并输出warped_cloth,表示变形后的衣服图像。
            if opt.warp_cloth:#如果指定了opt.warp_cloth参数为True,则将warped_cloth保存到文件中
                warped_cloth_name = result['warped_cloth_name']
                warped_cloth_path = os.path.join('dataset', 'warped_cloth', warped_cloth_name[0])
                if not os.path.exists(os.path.split(warped_cloth_path)[0]):#判断 warped_cloth_path 路径所在的目录是否存在
                    os.makedirs(os.path.split(warped_cloth_path)[0])#不存在则使用 os.makedirs() 创建该目录
                utils.save_image(warped_cloth * 0.5 + 0.5, warped_cloth_path)
                print('processing_%d'%i)
                continue

            #从result中提取出了多个tensor数据,将其赋值给对应的变量名。 
            source_parse = result['source_parse'].float().cuda()#输入的人物姿态解析图
            target_pose_embedding = result['target_pose_embedding'].float().cuda()#目标姿态的嵌入向量
            source_image = result['source_image'].float().cuda()#输入的人物图像
            cloth_parse = result['cloth_parse'].cuda()#衣服的姿态解析图
            cloth_image = result['cloth_image'].cuda()#衣服的图像
            target_pose_img = result['target_pose_img'].float().cuda()#目标姿态的图像
            cloth_parse = result['cloth_parse'].float().cuda()#转换成float类型
            source_parse_vis = result['source_parse_vis'].float().cuda()#可视化的人物姿态解析图

            #将衣服的信息添加到输入张量中
            "filter add cloth infomation"
            real_s = source_parse#赋值
            index = [x for x in list(range(20)) if x != 5 and x != 6 and x != 7]#将real_s张量中索引为5、6和7的通道(通道数从0开始)从通道维度上去掉,得到real_s_张量
            real_s_ = torch.index_select(real_s, 1, torch.tensor(index).cuda())#index_select函数从real_s张量的第1维度(即通道维度)中选择指定的索引,返回一个新的张量real_s_
            input_parse = torch.cat((real_s_, target_pose_embedding, cloth_parse), 1).cuda()#将real_s_、target_pose_embedding和cloth_parse三个张量沿着通道维度(dim=1)拼接起来
            
            #解析转换网络(Parsing Transformation Network):用于得到目标姿态下的目标语义图(人体分割图)。
            'P'
            generate_parse = generator_parsing(input_parse) # tanh 使用生成器模型generator_parsing对input_parse进行前向计算,生成一个张量generate_parse
            generate_parse = F.softmax(generate_parse, dim=1)#softmax操作将每个元素变为正数,再将它们加起来变为1,得到了一个概率分布。

            generate_parse_argmax = torch.argmax(generate_parse, dim=1, keepdim=True).float()#返回每一行中最大值的索引
            res = []
            for index in range(20):#for 循环将 generate_parse_argmax 的每一行转换成一个大小为 (batch_size, 20) 的 one-hot 向量。
                res.append(generate_parse_argmax == index)#one-hot向量是指在一个固定长度的向量中,只有一个元素是1,其他元素都是0。这个唯一的1的位置表示一个特定的类别或状态,方便地处理分类问题。
            generate_parse_argmax = torch.cat(res, dim=1).float()#将列表 res 中的所有 one-hot 向量按照第二个维度拼接起来,得到一个大小为 (batch_size, 20*20) 的 one-hot tensor,也就是 generate_parse_argmax。

            #详细外观生成网络(Detailed Appearance Generation Network):CP-VTON模型中的一部分,用于生成初步的虚拟试衣结果。
            "A"
            image_without_cloth = create_part(source_image, source_parse, 'image_without_cloth', False)#生成源图像中不包含衣服的部分
            input_app = torch.cat((image_without_cloth, warped_cloth, generate_parse), 1).cuda()#将 image_without_cloth、warped_cloth 和 generate_parse 拼接起来作为输入
            generate_img = generator_app_cpvton(input_app)#送入generator_app_cpvton生成generate_img ,一个大小为 (batch_size, 6, H, W) 的张量,其中前3个通道是前景图像,后3个通道是遮罩图像。
            p_rendered, m_composite = torch.split(generate_img, 3, 1) #通过 torch.split 将其拆分成前景图像 p_rendered 和遮罩图像 m_composite。
            p_rendered = F.tanh(p_rendered)#p_rendered 通过 F.tanh 转换到 [-1, 1] 范围内
            m_composite = F.sigmoid(m_composite)#m_composite 通过 F.sigmoid 转换到 [0, 1] 范围内
            p_tryon = warped_cloth * m_composite + p_rendered * (1 - m_composite)#最终的虚拟试衣结果 p_tryon 通过加权平均的方式得到,其中 warped_cloth 的权重为遮罩图像,p_rendered 的权重为遮罩图像的补集。
            refine_img = p_tryon

            #面部细化网络(Face refinement Network):用于为面部区域增加更多的细节和真实感。
            "F"
            generate_face = create_part(refine_img, generate_parse_argmax, 'face', False)#从分割结果(generate_parse_argmax)中提取出人脸部分(generate_face)。
            generate_img_without_face = refine_img - generate_face#与原始图像(refine_img)中的非人脸部分合成得到不包含人脸的图像(generate_img_without_face)。
            source_face = create_part(source_image, source_parse, 'face', False)#从源图像(source_image)中提取出人脸部分(source_face)。
            input_face = torch.cat((source_face, generate_face), 1)#将生成的人脸(generate_face)和源图像的人脸(source_face)连接起来
            fake_face = generator_face(input_face)#送入生成器网络(generator_face)生成伪造的人脸(fake_face)。
            fake_face = create_part(fake_face, generate_parse_argmax, 'face', False)
            refine_img = generate_img_without_face + fake_face#将伪造的人脸(fake_face)插入到不包含人脸的图像(generate_img_without_face)中,得到最终的输出(refine_img)

            #用于保存模型生成的结果图像
            "generate parse vis"
            if opt.save_time:
                generate_parse_vis = source_parse_vis
            else:
                generate_parse_vis = torch.argmax(generate_parse, dim=1, keepdim=True).permute(0,2,3,1).contiguous()
                generate_parse_vis = pose_utils.decode_labels(generate_parse_vis)#进行解码生成可视化的解析图像
            "save results"
            images = [source_image, cloth_image, target_pose_img, warped_cloth, source_parse_vis, generate_parse_vis, p_tryon, refine_img]#所需图像存储在 images 列表中
            pose_utils.save_img(images, os.path.join(refine_path, '%d.jpg')%(i))#将列表中的图像保存在指定的路径下

    torch.cuda.empty_cache()

main函数

if __name__ == "__main__":
    #定义四个模型的预训练权重的路径
    resume_gmm = "pretrained_checkpoint/step_009000.pth"
    resume_G_parse = 'pretrained_checkpoint/parsing.tar'
    resume_G_app_cpvton = 'pretrained_checkpoint/app.tar'
    resume_G_face = 'pretrained_checkpoint/face.tar'
    paths = [resume_gmm, resume_G_parse, resume_G_app_cpvton, resume_G_face]
    
    #创建Config类的实例opt来接受传入参数
    opt = Config().parse()
    if not os.path.exists(opt.forward_save_path):
        os.makedirs(opt.forward_save_path)

    forward(opt, paths, 4, opt.forward_save_path)
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值