pix2pixHD代码解析
一、pix2pixHD代码解析(train.py + test.py)
二、pix2pixHD代码解析(options设置)
三、pix2pixHD代码解析(dataset处理)
四、pix2pixHD代码解析(models搭建)
三、pix2pixHD代码解析(dataset处理)
data_loader.py
##########################################################################
# 创建数据集加载主函数
##########################################################################
def CreateDataLoader(opt):
from data.custom_dataset_data_loader import CustomDatasetDataLoader
data_loader = CustomDatasetDataLoader()
print(data_loader.name()) # 返回的名字为“CustomDatasetDataLoader”
data_loader.initialize(opt) # 初始化参数
return data_loader
custom_dataset_data_loader.py
import torch.utils.data
from data.base_data_loader import BaseDataLoader
# 创建数据集
def CreateDataset(opt):
dataset = None
from data.aligned_dataset import AlignedDataset
dataset = AlignedDataset()
print("dataset [%s] was created" % (dataset.name())) # 打印数据集名字为‘AlignedDataset’
dataset.initialize(opt) # 初始化数据集参数
return dataset # 返回创建好的数据集
# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
def name(self):
return 'CustomDatasetDataLoader'
def initialize(self, opt):
BaseDataLoader.initialize(self, opt) # 初始化参数
self.dataset = CreateDataset(opt) # 创建数据集
self.dataloader = torch.utils.data.DataLoader( # 加载创建好的数据集,并自定义相关参数
self.dataset,
batch_size=opt.batchSize,
shuffle=not opt.serial_batches,
num_workers=int(opt.nThreads))
def load_data(self):
return self.dataloader # 返回数据集
def __len__(self):
return min(len(self.dataset), self.opt.max_dataset_size) # 返回加载的数据集长度和一个epoch容许的加载最大容量
aligned_dataset.py
#############################################################################
# 数据读取的方式
#############################################################################
import os.path
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
# 返回一个字典,里面由整理好的数据集:图片 + 类别
class AlignedDataset(BaseDataset): # init里面都是些路径的设置
def initialize(self, opt):
self.opt = opt
self.root = opt.dataroot
### input A (label maps) # 标签图的路径
dir_A = '_A' if self.opt.label_nc == 0 else '_label'
self.dir_A = os.path.join(opt.dataroot, opt.