三、pix2pixHD代码解析(dataset处理)

pix2pixHD代码解析

一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)

三、pix2pixHD代码解析(dataset处理)

data_loader.py

##########################################################################
# 创建数据集加载主函数
##########################################################################
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name())                                             # 返回的名字为“CustomDatasetDataLoader”
    data_loader.initialize(opt)                                           # 初始化参数
    return data_loader

custom_dataset_data_loader.py

import torch.utils.data
from data.base_data_loader import BaseDataLoader


# 创建数据集
def CreateDataset(opt):
    dataset = None
    from data.aligned_dataset import AlignedDataset
    dataset = AlignedDataset()

    print("dataset [%s] was created" % (dataset.name()))               # 打印数据集名字为‘AlignedDataset’
    dataset.initialize(opt)                                            # 初始化数据集参数
    return dataset                                                     # 返回创建好的数据集


# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt)                           # 初始化参数
        self.dataset = CreateDataset(opt)                              # 创建数据集
        self.dataloader = torch.utils.data.DataLoader(                 # 加载创建好的数据集,并自定义相关参数
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads))

    def load_data(self):
        return self.dataloader                                         # 返回数据集

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)       # 返回加载的数据集长度和一个epoch容许的加载最大容量

aligned_dataset.py


#############################################################################
# 数据读取的方式
#############################################################################

import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image


# 返回一个字典,里面由整理好的数据集:图片 + 类别
class AlignedDataset(BaseDataset):                                           # init里面都是些路径的设置
    def initialize(self, opt):
        self.opt = opt
        self.root = opt.dataroot    

        ### input A (label maps)                                             # 标签图的路径
        dir_A = '_A' if self.opt.label_nc == 0 else '_label'
        self.dir_A = os.path.join(opt.dataroot, opt.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值