Bringing Old Photos Back to Life模型代码分析1(数据载入部分)



(1)Bringing Old Photos Back to Life原理和测试        

(2)

Bringing Old Photos Back to Life模型代码分析1(数据载入部分) 

Bringing Old Photos Back to Life模型代码分析2(模型部分)

(3)Bringing Old Photos Back to Life数据集及其训练

这一部分是关于数据预处理部分

文件在Global/data下,如图所示

 base_dataset.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
import random
#
class BaseDataset(data.Dataset):
    def __init__(self):
        super(BaseDataset, self).__init__()

    def name(self):
        return 'BaseDataset'

    def initialize(self, opt):
        pass
#获取裁剪参数
# 这个函数是根据用户指定的方式resize或者crop出合适大小的输入尺寸。
# size:输入图片的尺寸
def get_params(opt, size):
    w, h = size
    new_h = h
    new_w = w
    if opt.resize_or_crop == 'resize_and_crop':
        # opt.loadSize为自己输入的尺寸,将图像缩放到这个大小
        new_h = new_w = opt.loadSize  # 将宽和高设置为同样大小

    if opt.resize_or_crop == 'scale_width_and_crop': # we scale the shorter side into 256

        if w<h:
            new_w = opt.loadSize
            new_h = opt.loadSize * h // w   # 高度按照原图宽高比计算
        else:
            new_h=opt.loadSize
            new_w = opt.loadSize * w // h

    if opt.resize_or_crop=='crop_only':
        pass


    x = random.randint(0, np.maximum(0, new_w - opt.fineSize))
    y = random.randint(0, np.maximum(0, new_h - opt.fineSize))
    #随机进行是否裁剪
    flip = random.random() > 0.5    # 随机数是否大于0.5,flip是bool型变量,此行代码意思为随机生成True或者False
    return {'crop_pos': (x, y), 'flip': flip} # 最终的返回值,在data.aligned_dataset 45行,当作params传入了下方get_transform()函数

# 图像变换
def get_transform(opt, params, method=Image.BICUBIC, normalize=True):
    transform_list = []
    #重设置大小
    if 'resize' in opt.resize_or_crop: # # 若opt.resize_or_crop中有'resize'
        osize = [opt.loadSize, opt.loadSize]
        transform_list.append(transforms.Scale(osize, method))   
    #
    elif 'scale_width' in opt.resize_or_crop:
    #    transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.loadSize, method)))  ## Here , We want the shorter side to match 256, and Scale will finish it.
    #将输入的`PIL.Image`重新改变大小成给定的`size`即256
        transform_list.append(transforms.Scale(256,method))
    #裁剪
    if 'crop' in opt.resize_or_crop:
        if opt.isTrain:
            # 使用transforms.Lambda封装其为transforms策略
            transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.fineSize)))
        else:
            if opt.test_random_crop:
                transform_list.append(transforms.RandomCrop(opt.fineSize))
            else:
                transform_list.append(transforms.CenterCrop(opt.fineSize))

    ## when testing, for ablation study, choose center_crop directly.



    if opt.resize_or_crop == 'none':
        base = float(2 ** opt.n_downsample_global)
        if opt.netG == 'local':
            base *= (2 ** opt.n_local_enhancers)
        transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base, method)))

    if opt.isTrain and not opt.no_flip:
        transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))

    transform_list += [transforms.ToTensor()]

    if normalize:
        # mean和std均为0.5
        transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
                                                (0.5, 0.5, 0.5))]
    return transforms.Compose(transform_list)
#归一化到(-1,1)
def normalize():    
    return transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
#将图片进行设置大小为base的整倍数
def __make_power_2(img, base, method=Image.BICUBIC):
    ow, oh = img.size        
    h = int(round(oh / base) * base)
    w = int(round(ow / base) * base)
    if (h == oh) and (w == ow):
        return img
    return img.resize((w, h), method)
#修改图片为目标大小
def __scale_width(img, target_width, method=Image.BICUBIC):
    ow, oh = img.size
    if (ow == target_width):
        return img    
    w = target_width
    h = int(target_width * oh / ow)    
    return img.resize((w, h), method)
#对图片进行切割# 随机平移滑动裁剪
def __crop(img, pos, size):
    ow, oh = img.size
    x1, y1 = pos
    tw = th = size   # 输入的尺寸 opt.fineSize
    if (ow > tw or oh > th):
        #Image.crop(left, up, right, below) 其中left:与左边界的距离 up:与上边界的距离 right:还是与左边界的距离 below:还是与上边界的距离
        return img.crop((x1, y1, x1 + tw, y1 + th))     # 随机裁剪,因为虽然每次裁剪测大小一样,但是起始点位置不一样
    return img
#左右翻转
def __flip(img, flip):
    if flip:
        return img.transpose(Image.FLIP_LEFT_RIGHT)
    return img

Create_Bigfile.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os
import struct
from PIL import Image

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]

# 判断文件夹中是否有以上类型图片,没有则返回0
def is_image_file(filename):
    #如果不都为空、0、false,则any()返回true
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

#创建图片数据集,存在列表中并返回
def make_dataset(dir):
    images = []
    assert os.path.isdir(dir), '%s is not a valid directory' % dir

    # os.walk(top[, topdown=True[, onerror=None[, followlinks=False]]]) 通过在目录树中游走输出在目录中的文件名,top返回三项(root,dirs,files),分别代表:
    # 当前正在遍历的这个文件夹的本身的地址;  list类型,内容是该文件夹中所有的目录的名字(不包括子目录);  list类型,内容是该文件夹中所有的文件(不包括子目录)
    for root, _, fnames in sorted(os.walk(dir)):
        for fname in fnames:
            if is_image_file(fname):
                #print(fname)
                #拼接出图片的地址,并加入到images列表
                path = os.path.join(root, fname)
                images.append(path)

    return images

### Modify these 3 lines in your own environment
#需要修改以下三个变量:
#变量一:存放待训练数据集文件夹的父目录
indir="/home/ziyuwan/workspace/data/temp_old"
#变量二:待训练数据的文件夹,共有三个目标文件夹,分别为 : VOC数据集(用于生成假老照片)、真实黑白老照片、真实彩色老照片
target_folders=['VOC','Real_L_old','Real_RGB_old']
#变量三:输出生成结果的文件夹路径
out_dir ="/home/ziyuwan/workspace/data/temp_old"
###

if os.path.exists(out_dir) is False:
    os.makedirs(out_dir)

#遍历存放数据集的文件夹
for target_folder in target_folders:
    #拼接生成存放数据集文件夹的路径
    curr_indir = os.path.join(indir, target_folder)
    #生成的大文件路径(含问文件名)
    curr_out_file = os.path.join(os.path.join(out_dir, '%s.bigfile' % (target_folder)))
    image_lists = make_dataset(curr_indir)
    image_lists.sort()
    with open(curr_out_file, 'wb') as wfid:
        # write total image number
        wfid.write(struct.pack('i', len(image_lists)))
        for i, img_path in enumerate(image_lists):
             # write file name first
             img_name = os.path.basename(img_path)
             img_name_bytes = img_name.encode('utf-8')
             wfid.write(struct.pack('i', len(img_name_bytes)))
             wfid.write(img_name_bytes)
    #
    #             # write image data in
             with open(img_path, 'rb') as img_fid:
                 img_bytes = img_fid.read()
             wfid.write(struct.pack('i', len(img_bytes)))
             wfid.write(img_bytes)

             if i % 1000 == 0:
                 print('write %d images done' % i)

custom_dataset_data_loader.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data
import random
from data.base_data_loader import BaseDataLoader
from data import online_dataset_for_old_photos as dts_ray_bigfile

#根据训练的模型模块不同,返回对应的数据集
def CreateDataset(opt):
    dataset = None
    # 训练A或者B时,使用的数据集为非成对数据集
    if opt.training_dataset=='domain_A' or opt.training_dataset=='domain_B':

        dataset = dts_ray_bigfile.UnPairOldPhotos_SR()
    #当训练mapping时,载入成对数据集
    if opt.training_dataset=='mapping':
        if opt.random_hole:
            dataset = dts_ray_bigfile.PairOldPhotos_with_hole()
        else:
            dataset = dts_ray_bigfile.PairOldPhotos()
    print("dataset [%s] was created" % (dataset.name()))   # 打印数据集名字为‘
    dataset.initialize(opt)  # 初始化数据集参数
    return dataset  # 返回创建好的数据集
##创建数据载入器# 加载数据集
class CustomDatasetDataLoader(BaseDataLoader):
    def name(self):
        return 'CustomDatasetDataLoader'

    def initialize(self, opt):
        BaseDataLoader.initialize(self, opt) ## 初始化参数
        #创建数据集
        self.dataset = CreateDataset(opt)
        #创建数据载入器
        self.dataloader = torch.utils.data.DataLoader( ## 加载创建好的数据集,并自定义相关参数
            self.dataset,
            batch_size=opt.batchSize,
            shuffle=not opt.serial_batches,
            num_workers=int(opt.nThreads),
            drop_last=True)

    def load_data(self):
        return self.dataloader           # 返回数据集

    def __len__(self):
        return min(len(self.dataset), self.opt.max_dataset_size)# 返回加载的数据集长度和一个epoch容许的加载最大容量

data_loader.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
#创建数据载入器
##########################################################################
# 创建数据集加载主函数
########################################################################
def CreateDataLoader(opt):
    from data.custom_dataset_data_loader import CustomDatasetDataLoader
    data_loader = CustomDatasetDataLoader()
    print(data_loader.name()) # 返回的名字为“CustomDatasetDataLoader”
    data_loader.initialize(opt) # # 初始化参数
    return data_loader

image_foder.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import torch.utils.data as data
from PIL import Image
import os

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff'
]


def is_image_file(filename):
    ### any()函数用于判断给定的可迭代参数iterable是否全部为False,则返回False,如果有一个为True,则返回True。
    # 元素除了是0、空、FALSE外都算TRUE。
    # 函数等价于:
    # def any(iterable):
    #     for element in iterable:
    #         if element:
    #             return True
    #     return False
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)

# 制作数据集:获得数据集的图片路径列表
def make_dataset(dir):  # dir为数据集文件夹路径
    images = []# 创建空列表
    assert os.path.isdir(dir), '%s is not a valid directory' % dir   # 确认路径存在
    ### os.walk() 方法是一个简单易用的文件、目录遍历器,可以帮助我们高效的处理文件、目录方面的事情
    # top -- 是你所要遍历的目录的地址, 返回的是一个三元组(root,dirs,files)。
    # root 所指的是当前正在遍历的这个文件夹的本身的地址,和输入的os.walk(dir)种的dir一致
    # dirs 是一个 list ,内容是该文件夹中所有的 目录 的名字(不包括子目录),若无则为[]
    # files 同样是 list , 内容是该文件夹中所有的 文件 的名字(不包括子目录),若无则为[]
    for root, _, fnames in sorted(os.walk(dir)):   # fnames为文件中读取的照片文件
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)# 将文件夹路径dir 和 图片名称fname 结合起来
                images.append(path)  # 将图片路径存放到image列表里

    return images         # 返回图片路径列表


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root) # imgs为root目录下图片路径列表
        if len(imgs) == 0: # 图片数量 = 0 报错
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " +
                               ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]  # 获取指定图片路径
        img = self.loader(path) # 加载图片
        if self.transform is not None:
            img = self.transform(img)    # 图片进行变换
        if self.return_paths:
            return img, path               # 返回图片和路径
        else:
            return img     # 仅返回图片

    def __len__(self):
        return len(self.imgs)      # 返回指定目录下图片数量

Load_Bigfile.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import io
import os
import struct
from PIL import Image
#载入打包好的数据
class BigFileMemoryLoader(object):
    def __load_bigfile(self):
        print('start load bigfile (%0.02f GB) into memory' % (os.path.getsize(self.file_path)/1024/1024/1024))
        with open(self.file_path, 'rb') as fid:
            self.img_num = struct.unpack('i', fid.read(4))[0]
            self.img_names = []
            self.img_bytes = []
            print('find total %d images' % self.img_num)
            for i in range(self.img_num):
                img_name_len = struct.unpack('i', fid.read(4))[0]
                img_name = fid.read(img_name_len).decode('utf-8')
                self.img_names.append(img_name)
                img_bytes_len = struct.unpack('i', fid.read(4))[0]
                self.img_bytes.append(fid.read(img_bytes_len))
                if i % 5000 == 0:
                    print('load %d images done' % i)
            print('load all %d images done' % self.img_num)
    #初始化
    def __init__(self, file_path):
        super(BigFileMemoryLoader, self).__init__()
        self.file_path = file_path
        self.__load_bigfile()
    #返回图片名字和图片
    def __getitem__(self, index):
        try:
            img = Image.open(io.BytesIO(self.img_bytes[index])).convert('RGB')
            return self.img_names[index], img
        except Exception:
            print('Image read error for index %d: %s' % (index, self.img_names[index]))
            return self.__getitem__((index+1)%self.img_num)

    #图片数目
    def __len__(self):
        return self.img_num

online_dataset_for_old_photos.py

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.

import os.path
import io
import zipfile
from data.base_dataset import BaseDataset, get_params, get_transform, normalize
from data.image_folder import make_dataset
from PIL import Image
import torchvision.transforms as transforms
import numpy as np
from data.Load_Bigfile import BigFileMemoryLoader
import random
import cv2
from io import BytesIO
#图片转矩阵
def pil_to_np(img_PIL):
    '''Converts image in PIL format to np.array.

    From W x H x C [0...255] to C x W x H [0..1]
    '''
    ar = np.array(img_PIL)

    if len(ar.shape) == 3:
        ar = ar.transpose(2, 0, 1)
    else:
        ar = ar[None, ...]

    return ar.astype(np.float32) / 255.

#矩阵转图片
def np_to_pil(img_np):
    '''Converts image in np.array format to PIL image.

    From C x W x H [0..1] to  W x H x C [0...255]
    '''
    ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)

    if img_np.shape[0] == 1:
        ar = ar[0]
    else:
        ar = ar.transpose(1, 2, 0)

    return Image.fromarray(ar)
##
#以下合成噪声图片
##
def synthesize_salt_pepper(image,amount,salt_vs_pepper):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    out = img_pil.copy()
    p = amount
    q = salt_vs_pepper
    flipped = np.random.choice([True, False], size=img_pil.shape,
                               p=[p, 1 - p])
    salted = np.random.choice([True, False], size=img_pil.shape,
                              p=[q, 1 - q])
    peppered = ~salted
    out[flipped & salted] = 1
    out[flipped & peppered] = 0.
    noisy = np.clip(out, 0, 1).astype(np.float32)


    return np_to_pil(noisy)

def synthesize_gaussian(image,std_l,std_r):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss
    noisy=np.clip(noisy,0,1).astype(np.float32)

    return np_to_pil(noisy)

def synthesize_speckle(image,std_l,std_r):

    ## Give PIL, return the noisy PIL

    img_pil=pil_to_np(image)

    mean=0
    std=random.uniform(std_l/255.,std_r/255.)
    gauss=np.random.normal(loc=mean,scale=std,size=img_pil.shape)
    noisy=img_pil+gauss*img_pil
    noisy=np.clip(noisy,0,1).astype(np.float32)

    return np_to_pil(noisy)

#图片缩小
def synthesize_low_resolution(img):
    w,h=img.size

    new_w=random.randint(int(w/2),w)
    new_h=random.randint(int(h/2),h)

    img=img.resize((new_w,new_h),Image.BICUBIC)

    if random.uniform(0,1)<0.5:
        img=img.resize((w,h),Image.NEAREST)
    else:
        img = img.resize((w, h), Image.BILINEAR)

    return img

#处理图片
def convertToJpeg(im,quality):
    #在内存中读写bytes
    with BytesIO() as f:
        im.save(f, format='JPEG',quality=quality)
        f.seek(0)
        #使用Image.open读出图像,然后转换为RGB通道,去掉透明通道A
        return Image.open(f).convert('RGB')

#由(高斯)噪声生成图片
def blur_image_v2(img):


    x=np.array(img)
    kernel_size_candidate=[(3,3),(5,5),(7,7)]
    kernel_size=random.sample(kernel_size_candidate,1)[0]
    std=random.uniform(1.,5.)

    #print("The gaussian kernel size: (%d,%d) std: %.2f"%(kernel_size[0],kernel_size[1],std))
    blur=cv2.GaussianBlur(x,kernel_size,std)

    return Image.fromarray(blur.astype(np.uint8))
#由以上噪声函数随机生成含有噪声的图片
def online_add_degradation_v2(img):

    task_id=np.random.permutation(4)

    for x in task_id:
        if x==0 and random.uniform(0,1)<0.7:
            img = blur_image_v2(img)
        if x==1 and random.uniform(0,1)<0.7:
            flag = random.choice([1, 2, 3])
            if flag == 1:
                img = synthesize_gaussian(img, 5, 50)
            if flag == 2:
                img = synthesize_speckle(img, 5, 50)
            if flag == 3:
                img = synthesize_salt_pepper(img, random.uniform(0, 0.01), random.uniform(0.3, 0.8))
        if x==2 and random.uniform(0,1)<0.7:
            img=synthesize_low_resolution(img)

        if x==3 and random.uniform(0,1)<0.7:
            img=convertToJpeg(img,random.randint(40,100))

    return img

#根据mask生成带有折痕的图片
#原论文中对于一些复杂的折痕会出现处理不佳的情况,在此进行改进,而不是简单进行加mask,
def irregular_hole_synthesize(img,mask):

    img_np=np.array(img).astype('uint8')
    mask_np=np.array(mask).astype('uint8')
    mask_np=mask_np/255
    img_new=img_np*(1-mask_np)+mask_np*255


    hole_img=Image.fromarray(img_new.astype('uint8')).convert("RGB")
    #L为灰度图像
    return hole_img,mask.convert("L")
#生成全黑三通道图像mask
def zero_mask(size):
    x=np.zeros((size,size,3)).astype('uint8')
    mask=Image.fromarray(x).convert("RGB")
    return mask


#非成对的老照片图像载入器(合成的老的和真实的老的照片,他们无需对应的,合成的老的照片由VOC数据集经处理生成)
class UnPairOldPhotos_SR(BaseDataset):  ## Synthetic + Real Old
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'domainA' in opt.name
        self.task = 'old_photo_restoration_training_vae'
        self.dir_AB = opt.dataroot
        # 载入VOC以及真实灰度、彩色图
        if self.isImage:

            self.load_img_dir_L_old=os.path.join(self.dir_AB,"Real_L_old.bigfile")
            self.load_img_dir_RGB_old=os.path.join(self.dir_AB,"Real_RGB_old.bigfile")
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")

            self.loaded_imgs_L_old=BigFileMemoryLoader(self.load_img_dir_L_old)
            self.loaded_imgs_RGB_old=BigFileMemoryLoader(self.load_img_dir_RGB_old)
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)

        else:
            # self.load_img_dir_clean=os.path.join(self.dir_AB,self.opt.test_dataset)
            self.load_img_dir_clean=os.path.join(self.dir_AB,"VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean=BigFileMemoryLoader(self.load_img_dir_clean)

        ####
        print("-------------Filter the imgs whose size <256 in VOC-------------")
        self.filtered_imgs_clean=[]
        # 过滤出VOC中小于256的图片
        for i in range(len(self.loaded_imgs_clean)):
            img_name,img=self.loaded_imgs_clean[i]
            h,w=img.size
            if h<256 or w<256:
                continue
            self.filtered_imgs_clean.append((img_name,img))

        print("--------Origin image num is [%d], filtered result is [%d]--------" % (
        len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        ## Filter these images whose size is less than 256

        # self.img_list=os.listdir(load_img_dir)
        self.pid = os.getpid()

    def __getitem__(self, index):


        is_real_old=0

        sampled_dataset=None
        degradation=None
        #随机抽取一张图片(从合成的老照片 和 真实老照片 中)
        if self.isImage: 
            P=random.uniform(0,2)
            if P>=0 and P<1:
                if random.uniform(0,1)<0.5:
                    sampled_dataset=self.loaded_imgs_L_old
                    self.load_img_dir=self.load_img_dir_L_old
                else:
                    sampled_dataset=self.loaded_imgs_RGB_old
                    self.load_img_dir=self.load_img_dir_RGB_old
                is_real_old=1
            if P>=1 and P<2:
                sampled_dataset=self.filtered_imgs_clean
                self.load_img_dir=self.load_img_dir_clean
                degradation=1
        else:
            #载入过滤后小于256大小的图
            sampled_dataset=self.filtered_imgs_clean
            self.load_img_dir=self.load_img_dir_clean

        sampled_dataset_len=len(sampled_dataset)

        index=random.randint(0,sampled_dataset_len-1)

        img_name,img = sampled_dataset[index]

        if degradation is not None:
            #对图片进行降质做旧处理
            img=online_add_degradation_v2(img)

        path=os.path.join(self.load_img_dir,img_name)

        # AB = Image.open(path).convert('RGB')
        # split AB image into A and B

        # apply the same transform to both A and B
        #随机对图片转换为灰度图
        if random.uniform(0,1) <0.1:
            img=img.convert("L")
            img=img.convert("RGB")
            ## Give a probability P, we convert the RGB image into L

        #调整大小
        A=img
        w,h=A.size
        if w<256 or h<256:
            A=transforms.Scale(256,Image.BICUBIC)(A)
        # 将图片裁剪为256*256,对于一些小于256的老照片,先进行调整大小
        ## Since we want to only crop the images (256*256), for those old photos whose size is smaller than 256, we first resize them.
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)

        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)

        #存入字典
        input_dict = {'label': A_tensor, 'inst': is_real_old, 'image': A_tensor,
                        'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):
        return len(self.loaded_imgs_clean) ## actually, this is useless, since the selected index is just a random number

    def name(self):
        return 'UnPairOldPhotos_SR'

#成对图像载入器(原始图及其合成旧图)
class PairOldPhotos(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        #训练模式,载入VOC
        if opt.isTrain:
            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            #过滤出VOC中小于256的图片
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                h, w = img.size
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))

            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))
        #测试模式时,仅载入测试集
        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)

        self.pid = os.getpid()

    def __getitem__(self, index):


        #训练模式
        if self.opt.isTrain:
            #(B为清晰VOC数据集)
            img_name_clean,B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)
            #生成成对图像(B为清晰VOC数据集,A对应的含噪声的图像)
            if self.opt.use_v2_degradation:
                A=online_add_degradation_v2(B)
            ### Remind: A is the input and B is corresponding GT
        else:
            #测试模式
            #(B为清晰VOC数据集,A对应的含噪声的图像)
            if self.opt.test_on_synthetic:

                img_name_B,B=self.loaded_imgs[index]
                A=online_add_degradation_v2(B)
                img_name_A=img_name_B
                path = os.path.join(self.load_img_dir, img_name_A)
            else:
                img_name_A,A=self.loaded_imgs[index]
                img_name_B,B=self.loaded_imgs[index]
                path = os.path.join(self.load_img_dir, img_name_A)

        #去掉透明通道
        if random.uniform(0,1)<0.1 and self.opt.isTrain:
            A=A.convert("L")
            B=B.convert("L")
            A=A.convert("RGB")
            B=B.convert("RGB")
        ## In P, we convert the RGB into L


        ##test on L

        # split AB image into A and B
        # w, h = img.size
        # w2 = int(w / 2)
        # A = img.crop((0, 0, w2, h))
        # B = img.crop((w2, 0, w, h))
        w,h=A.size
        if w<256 or h<256:
            A=transforms.Scale(256,Image.BICUBIC)(A)
            B=transforms.Scale(256, Image.BICUBIC)(B)

        # apply the same transform to both A and B
        #获取变换相关参数
        transform_params = get_params(self.opt, A.size)
        #变换数据,数据增强
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)

        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': inst_tensor, 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):

        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)
        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos'

#成对带折痕图像载入器
class PairOldPhotos_with_hole(BaseDataset):
    def initialize(self, opt):
        self.opt = opt
        self.isImage = 'imagegan' in opt.name
        self.task = 'old_photo_restoration_training_mapping'
        self.dir_AB = opt.dataroot
        #训练模式下,载入成对的带有裂痕的合成图片
        if opt.isTrain:
            self.load_img_dir_clean= os.path.join(self.dir_AB, "VOC_RGB_JPEGImages.bigfile")
            self.loaded_imgs_clean = BigFileMemoryLoader(self.load_img_dir_clean)

            print("-------------Filter the imgs whose size <256 in VOC-------------")
            #过滤出大小小于256的图片
            self.filtered_imgs_clean = []
            for i in range(len(self.loaded_imgs_clean)):
                img_name, img = self.loaded_imgs_clean[i]
                h, w = img.size
                if h < 256 or w < 256:
                    continue
                self.filtered_imgs_clean.append((img_name, img))

            print("--------Origin image num is [%d], filtered result is [%d]--------" % (
            len(self.loaded_imgs_clean), len(self.filtered_imgs_clean)))

        else:
            self.load_img_dir=os.path.join(self.dir_AB,opt.test_dataset)
            self.loaded_imgs=BigFileMemoryLoader(self.load_img_dir)
        #载入不规则mask
        self.loaded_masks = BigFileMemoryLoader(opt.irregular_mask)

        self.pid = os.getpid()

    def __getitem__(self, index):



        if self.opt.isTrain:
            img_name_clean,B = self.filtered_imgs_clean[index]
            path = os.path.join(self.load_img_dir_clean, img_name_clean)


            B=transforms.RandomCrop(256)(B)
            A=online_add_degradation_v2(B)
            ### Remind: A is the input and B is corresponding GT

        else:
            img_name_A,A=self.loaded_imgs[index]
            img_name_B,B=self.loaded_imgs[index]
            path = os.path.join(self.load_img_dir, img_name_A)

            #A=A.resize((256,256))
            A=transforms.CenterCrop(256)(A)
            B=A

        if random.uniform(0,1)<0.1 and self.opt.isTrain:
            A=A.convert("L")
            B=B.convert("L")
            A=A.convert("RGB")
            B=B.convert("RGB")
        ## In P, we convert the RGB into L

        if self.opt.isTrain:
            #载入mask
            mask_name,mask=self.loaded_masks[random.randint(0,len(self.loaded_masks)-1)]
        else:
            # 载入mask
            mask_name, mask = self.loaded_masks[index%100]
        #调整mask大小
        mask = mask.resize((self.opt.loadSize, self.opt.loadSize), Image.NEAREST)

        if self.opt.random_hole and random.uniform(0,1)>0.5 and self.opt.isTrain:
            mask=zero_mask(256)

        if self.opt.no_hole:
            mask=zero_mask(256)

        #由mask合成带有折痕的图片
        A,_=irregular_hole_synthesize(A,mask)

        if not self.opt.isTrain and self.opt.hole_image_no_mask:
            mask=zero_mask(256)
        #获取做旧变换参数
        transform_params = get_params(self.opt, A.size)
        A_transform = get_transform(self.opt, transform_params)
        B_transform = get_transform(self.opt, transform_params)
        #对mask进行相同的左右翻转
        if transform_params['flip'] and self.opt.isTrain:
            mask=mask.transpose(Image.FLIP_LEFT_RIGHT)
        #归一化
        mask_tensor = transforms.ToTensor()(mask)


        B_tensor = inst_tensor = feat_tensor = 0
        A_tensor = A_transform(A)
        B_tensor = B_transform(B)

        input_dict = {'label': A_tensor, 'inst': mask_tensor[:1], 'image': B_tensor,
                    'feat': feat_tensor, 'path': path}
        return input_dict

    def __len__(self):

        if self.opt.isTrain:
            return len(self.filtered_imgs_clean)

        else:
            return len(self.loaded_imgs)

    def name(self):
        return 'PairOldPhotos_with_hole'

  • 2
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Alocus_

如果我的内容帮助到你,打赏我吧

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值