虚拟试穿代码理解Down to the Last Detail: Virtual Try-on with Fine-grained Details(demo.py上)


Down to the Last Detail: Virtual Try-on with Fine-grained Details, ACM MM 2020
Paper, Code/Model, ArXiv
深入到最后一个细节:能雕刻细节的虚拟试穿
理解得不对的地方欢迎评论区指正~

demo.py

引用库:

  • torch:PyTorch的核心库,提供张量数据结构和数值计算操作等基本功能。
  • torch.nn:PyTorch中神经网络模块的库,提供各种层和模型的定义和实现。
  • models.networks:该脚本定义了自己的网络模型,其中包括Define_G和Define_D模型。
  • torch.optim:PyTorch中优化算法的库,包括SGD、Adam、Adagrad、AdamW等优化器。
  • config:一个Python脚本,定义了训练和测试的各种参数。
  • os:提供了与操作系统交互的函数,用于访问文件系统等操作。
  • os.path:提供了与路径相关的函数,用于处理文件和目录路径。
  • torch.utils.data:PyTorch中数据加载和预处理的库,包括DataLoader等数据加载器。
  • transforms:用于对图像进行变换的库,可以进行裁剪、缩放、旋转等操作。
  • data.regular_dataset:一个Python脚本,定义了数据集类,用于读取和处理训练和测试数据。
  • data.demo_dataset:一个Python脚本,定义了演示数据集类。
  • utils.transforms:自定义的图像变换函数。
  • time:Python的时间库,用于时间相关操作。
  • datetime:Python的日期和时间库,用于日期和时间相关操作。
  • torch.backends.cudnn:PyTorch针对GPU优化的库,用于提升训练速度。
  • numpy:Python的数值计算库,提供了大量数学函数和矩阵操作。
  • torchvision.utils:PyTorch中关于图像处理的工具函数。
  • PIL.Image:Python的图像处理库,提供了图像读取、保存、缩放、旋转等操作。
  • utils.pose_utils:自定义的人体姿态估计工具函数。
  • torch.nn.functional:PyTorch中一些常用的函数库。
  • lib.geometric_matching_multi_gpu:自定义的几何变换库,用于图像变形操作。
  • cv2:Python中OpenCV库的接口,用于图像和视频处理

load_model函数:

#load_model功能是加载预训练模型并返回预训练好的模型
def load_model(model, path):

    checkpoint = torch.load(path)#从指定路径(path)加载预训练模型(checkpoint)
    try:
        model.load_state_dict(checkpoint)#加载预训练模型的状态字典,以便将预训练模型的参数应用于该模型。
    except:
        model.load_state_dict(checkpoint.state_dict())#如果出现加载不成功的情况,使用 checkpoint.state_dict() 来加载
    model = model.cuda()#将模型加载到 GPU 中,以便后续在 GPU 上进行计算

    model.eval()#将模型设置为测试模式,即关闭 dropout 和 batch normalization 等对模型参数的影响。
    print(20*'=')
    for param in model.parameters():#冻结模型参数,避免在预测时更新模型参数,节省计算资源。
        param.requires_grad = False
#状态字典(state dictionary)是指在深度学习模型中保存了模型所有可学习参数的字典,每个参数都对应一个键值对。
# 在 PyTorch 中,状态字典通常是由模型的 state_dict() 方法返回的,它包含了所有的权重、偏差等参数信息。
# 状态字典可以通过调用 load_state_dict() 方法加载到模型中,使得模型的权重与状态字典中保存的权重保持一致。
# 通常在保存和加载模型时会用到状态字典,以便能够准确地保存和加载模型的权重,方便进行模型的训练和推理。

#Dropout是一种在深度神经网络中使用的正则化技术,可以在训练过程中减少过拟合现象。
#Dropout在每个训练批次中随机地使一些神经元失活,即将它们的输出设置为0,从而使得每个神经元都有一定的概率被临时忽略。
# 通过这种方式,dropout可以迫使神经元们相互独立地学习,而不是互相依赖,从而防止模型对训练集的过拟合。

#Batch normalization (批量归一化)是一种常用的神经网络正则化技术,旨在减轻训练深度神经网络时的内部协变量偏移 (internal covariate shift) 问题。
#内部协变量偏移是指在训练过程中,网络内部每一层的输入分布会随着网络的参数不断变化而发生变化,导致训练过程变得困难。
#Batch normalization 将每个 batch 的输入进行标准化处理,使其均值为0,方差为1,从而减少输入分布的变化,缓解了内部协变量偏移问题。
#同时,Batch normalization 也可以使得网络更加稳定,提高模型的收敛速度和准确性。Batch normalization 通常在网络的激活函数之前应用。

forward函数(上):

模型初始化模块:

def forward(opt, paths, gpu_ids, refine_path):#opt包含了试穿的相关参数,paths包含了用于加载模型的路径,gpu_ids指定了使用的GPU ID,refine_path指定了渲染的参考图片的路径。
    cudnn.enabled = True#启用cudnn并启用benchmark模式,以提高模型运行效率
    cudnn.benchmark = True#benchmark模式会根据输入数据的大小自动寻找最优算法来实现卷积、池化等操作,从而提高计算性能。
    opt.output_nc = 3#输出图像通道数为3

    gmm = GMM(opt)#创建一个GMM模型实例gmm
    gmm = torch.nn.DataParallel(gmm).cuda()#torch.nn.DataParallel进行封装,使其可以并行地在多个GPU上运行

    # 'batch'
    # 初始化三个生成器模型,对应人体解析、虚拟试穿和面部优化三阶段
    generator_parsing = Define_G(opt.input_nc_G_parsing, opt.output_nc_parsing, opt.ndf, opt.netG_parsing, opt.norm, 
                            not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)
    
    generator_app_cpvton = Define_G(opt.input_nc_G_app, opt.output_nc_app, opt.ndf, opt.netG_app, opt.norm, 
                            not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids, with_tanh=False)
    
    generator_face = Define_G(opt.input_nc_D_face, opt.output_nc_face, opt.ndf, opt.netG_face, opt.norm, 
                            not opt.no_dropout, opt.init_type, opt.init_gain, opt.gpu_ids)

    models = [gmm, generator_parsing, generator_app_cpvton, generator_face]#定义了一个列表 models,其中包含四个生成器模型。
    for model, path in zip(models, paths):#通过迭代模型列表 models 中的每个模型以及对应的路径 paths,将模型参数加载到模型中。
        load_model(model, path)#调用load_model函数来载入预训练模型的权重参数
    print('==>loaded model')

数据增强模块:

    augment = {}#创建空的字典augment,存放数据增强的配置信息

    #根据 PyTorch 版本选择不同的数据增强方式,数据增强(Data augmentation)是指在保持数据标签不变的前提下,对数据进行一定的变换,以增加数据量,提高模型的泛化能力。
    if '0.4' in torch.__version__:#如果 PyTorch 版本为 0.4,则将数据转换为一个三维张量,即 [C, H, W],其中 C 表示图像的通道数,H 和 W 分别表示图像的高度和宽度。
        #'3' 表示 RGB 彩色图像,'1' 表示单通道灰度图像
        augment['3'] = transforms.Compose([
                                    # transforms.Resize(256),
                                    transforms.ToTensor(),#将图像转换为 PyTorch 中的 tensor 类型,将像素值缩放到 [0, 1] 范围内,同时调整图像的通道顺序,变为 [C, H, W] 的形式。
                                    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))#对数据进行归一化操作,将图像的每个通道的像素值缩放到 [-1, 1] 范围内,其中第一个元组参数表示每个通道的均值,第二个元组参数表示每个通道的标准差。
            ]) # change to [C, H, W]
        augment['1'] = augment['3']#将 augment['1'] 设置为和 augment['3'] 相同的数据增强方式

    else:
        augment['3'] = transforms.Compose([
                                # transforms.Resize(256),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
        ]) # change to [C, H, W]

        augment['1'] = transforms.Compose([
                                # transforms.Resize(256),
                                transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))
        ]) # change to [C, H, W]

数据初始化模块:

    val_dataset = DemoDataset(opt, augment=augment)#创建了一个 DemoDataset 对象 val_dataset,用于加载验证数据集
    val_dataloader = DataLoader(#创建了一个 DataLoader 对象 val_dataloader。dataLoader是PyTorch中用于加载和迭代数据的一个工具。它可以自动将数据集划分为batch,同时还可以实现数据的并行加载
                    val_dataset,#从 val_dataset 中加载数据
                    shuffle=False,#不打乱数据
                    drop_last=False,#不丢弃数据集中不足一个 batch 的数据
                    num_workers=opt.num_workers,#表示使用多少个线程加载数据
                    batch_size = opt.batch_size_v,#将验证集的数据按照 batch_size_v 大小进行分批
                    pin_memory=True)#将数据加载到 CUDA 固定内存中,以提高数据加载效率。
  • 2
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值