Colorful Planet Image Generation 2: Using a Traditional GAN Discriminator and a Markovian (PatchGAN) Discriminator Together (PyTorch)


Previous installment: Colorful Planet Image Generation 1: Implemented with a GAN (PyTorch)

1. Description

The dataset is the same one used last time in the plain-GAN planet generation post (PyTorch): the same 32 planet images, with the batch size again set to 32 so the experiments are directly comparable.
Training images (32 images, sourced from Space Engine):
[image]

1.1 Results with the traditional GAN discriminator

A traditional GAN discriminator outputs a single value per image indicating real or fake. The images obtained after training look like this:
[image]
Conclusion: the overall outline is circular, but the detailed regions are not very sharp.

1.2 Results with the Markovian (PatchGAN) discriminator

A PatchGAN discriminator uses no linear layers; its output is a matrix in which each value judges real/fake for the corresponding patch of the input image, and the mean over all patches serves as the balanced overall score.

Note: as a beginner, I flatten the output matrix into a 1-D vector and compute MSELoss against a 1-D label vector of the same length. I cannot guarantee this is the canonical way to do it; it is simply my own interpretation, but the final results turned out reasonably well. Corrections from more experienced readers are welcome.
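A minimal sketch of that flattening scheme (the shapes are illustrative, matching the 11x11 patch score map the network below produces for 264x264 inputs):

```python
import torch
import torch.nn as nn

# Pretend output of the patch discriminator: one sigmoid score per patch.
batch_size = 32
patch_map = torch.rand(batch_size, 1, 11, 11)

# Flatten the 2-D score map into a 1-D vector per image ...
patch_scores = patch_map.view(batch_size, -1)  # shape: (32, 121)

# ... and compare it with an equally long label vector using MSELoss.
real_labels = torch.ones(batch_size, 121) - 0.1  # label smoothing: all 0.9
loss = nn.MSELoss()(patch_scores, real_labels)
print(patch_scores.shape, loss)
```

Because MSELoss averages over all elements, this is numerically the same as averaging the per-patch losses, which is what "take the mean over the patches" amounts to.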

The training code that uses the Patch discriminator alone is not given separately, but all the relevant pieces are contained in the combined code below.
The results generated after training with the Patch discriminator alone:
[image]
Conclusion: the image details become much sharper, but the overall structure is no longer a clean circle. Because the Patch discriminator only attends to local regions, the image is visibly composed of several arcs with different radii, producing swirl-shaped edges.

1.3 Using both discriminators together

Building on the two experiments above, I decided to try using both discriminators at once: each discriminator is updated separately, and the generator is then updated with the sum of the two discriminator losses, aiming to capture both the overall outline and the local details.
The training images are 264x264 pixels. Setting config.from_old_model to True loads the network and optimizer parameters saved by the previous run, so training resumes where it left off; by default a checkpoint is saved every 100 epochs, configurable in config.
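In code, the combined generator update boils down to summing the two losses. A minimal illustration with stand-in discriminator outputs (the shapes and the 0.9 smoothed labels follow the full code below):

```python
import torch
import torch.nn as nn

criterion = nn.MSELoss()

# Stand-in outputs for a batch of 4 fake images: the global discriminator
# emits one score per image, the patch discriminator one score per patch.
fake_output_global = torch.rand(4, 1)
fake_output_patch = torch.rand(4, 121)

# The generator wants both discriminators to answer "real" (0.9 after
# one-sided label smoothing), so both losses target the real labels.
real_labels_global = torch.full((4, 1), 0.9)
real_labels_patch = torch.full((4, 121), 0.9)

# The two losses are simply added before backpropagating through G.
G_loss = criterion(fake_output_global, real_labels_global) \
         + criterion(fake_output_patch, real_labels_patch)
print(G_loss)
```

Summing the losses means the gradients from both discriminators are accumulated in a single backward pass, so the generator is pushed toward images that satisfy both critics at once.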

2. Code

As before, the code is split into two parts: the main program main.py and the model code model.py.

2.1 Model code: model.py

import torch
import torch.nn as nn


# Generator, built on upsampling
class G_net(nn.Module):
    def __init__(self):
        super(G_net, self).__init__()
        self.expand = nn.Sequential(
            nn.Linear(128, 2048),
            nn.BatchNorm1d(2048),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 4096),
            nn.BatchNorm1d(4096),
            nn.Dropout(0.5),
            nn.LeakyReLU(0.2),
        )
        self.gen = nn.Sequential(
            # Transposed convolutions grow the spatial size; keeping the kernel
            # size divisible by the stride reduces checkerboard artifacts
            nn.ConvTranspose2d(64, 128, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(256, 64, kernel_size=6, stride=3, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(32),
            nn.LeakyReLU(0.2),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2),
            # Regular convolutions at the tail compress channels and further reduce checkerboard artifacts
            nn.Conv2d(16, 8, kernel_size=5, stride=1, padding=1),
            nn.BatchNorm2d(8),
            nn.LeakyReLU(0.2),
            nn.Conv2d(8, 4, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(4, 3, kernel_size=3, stride=1, padding=1),

            # Constrain the output to [-1, 1]
            nn.Tanh()
        )

    def forward(self, img_seeds):
        img_seeds = self.expand(img_seeds)
        # Reshape the linear output into 2-D feature maps
        img_seeds = img_seeds.view(-1, 64, 8, 8)
        output = self.gen(img_seeds)
        return output


# Global discriminator: a traditional GAN discriminator
class D_net_global(nn.Module):
    def __init__(self):
        super(D_net_global, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 16, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, True),
        )
        self.classifier = nn.Sequential(
            nn.Linear(1024, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        features = self.features(img)
        features = features.view(features.shape[0], -1)
        output = self.classifier(features)
        return output


# Local discriminator: PatchGAN
class D_net_patch(nn.Module):
    def __init__(self):
        super(D_net_patch, self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, img):
        features = self.features(img)
        features = features.view(features.shape[0], -1)
        return features


# Return the generator model
def get_G_model(from_old_model, device, model_path):
    model = G_net()
    # Load previously saved parameters from disk
    if from_old_model:
        model.load_state_dict(torch.load(model_path))
    # Move the model to the compute device
    model = model.to(device)
    return model

# Return the global discriminator model
def get_D_model_global(from_old_model, device, model_path):
    model = D_net_global()
    # Load previously saved parameters from disk
    if from_old_model:
        model.load_state_dict(torch.load(model_path))
    # Move the model to the compute device
    model = model.to(device)
    return model

# Return the local (patch) discriminator model
def get_D_model_patch(from_old_model, device, model_path):
    model = D_net_patch()
    # Load previously saved parameters from disk
    if from_old_model:
        model.load_state_dict(torch.load(model_path))
    # Move the model to the compute device
    model = model.to(device)
    return model
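As a sanity check on the architectures above, the spatial sizes can be traced with the standard convolution output formula. This small helper (my own addition, not part of the project files) confirms that a 264x264 input leaves D_net_patch as an 11x11 score map, i.e. 121 patch values, which is exactly where `patch_label_length = 121` in the config of main.py comes from:

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Spatial output size of a Conv2d: floor((n + 2p - k) / s) + 1."""
    return (size + 2 * padding - kernel) // stride + 1

# D_net_patch: four stride-2 convs, then one stride-1 conv, all unpadded.
size = 264
for kernel, stride in [(4, 2), (4, 2), (4, 2), (4, 2), (4, 1)]:
    size = conv_out(size, kernel, stride)

print(size, size * size)  # 11 121
```

The same formula applied to D_net_global yields a final 8x8 map with 16 channels, matching the `nn.Linear(1024, ...)` classifier input (16 * 8 * 8 = 1024).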

2.2 Training code: main.py

from torch.utils.data import Dataset, DataLoader
import time
from torch.optim import AdamW
from model import *
from torchvision.utils import save_image
import random
from torch.autograd import Variable
import os
import cv2
from albumentations import Normalize, Compose, Resize
from albumentations.pytorch import ToTensorV2
from apex import amp
import pickle


# ------------------------------------config------------------------------------
class config:
    # Random seed, and whether to fix it
    seed = 26
    use_seed = True

    # Whether to resume training from model parameters previously saved to disk
    from_old_model = False

    # Use apex mixed precision to speed up training
    use_apex = True

    # How many epochs to run before stopping
    epochs = 20000
    # Batch size
    batchSize = 32

    # Input resolution of the training images
    img_size = 264

    # Dimension of the random normal seed fed to the generator
    # (if changed, the matching layer sizes in model.py must be changed too)
    img_seed_dim = 128

    # Probability of swapping the real and fake labels while training the discriminator D
    D_train_label_exchange = 0.1

    # Whether to add noise to the images
    add_noise = False

    # Label length per image, derived from the patch discriminator
    # architecture and the input resolution (an 11x11 score map, so 121)
    patch_label_length = 121

    # Keep the dataset in memory or read it from disk;
    # a small dataset can be loaded entirely into memory for speed
    read_from = "Memory"
    # read_from = "Disk"

    # Paths of the saved model parameter files
    G_model_path = "G_model.pth"
    D_model_global_path = "D_model_global.pth"
    D_model_patch_path = "D_model_patch.pth"

    # Paths of the saved optimizer parameter files
    G_optimizer_path = "G_optimizer.pth"
    D_optimizer_global_path = "D_optimizer_global.pth"
    D_optimizer_patch_path = "D_optimizer_patch.pth"

    # Path of the file recording the cumulative number of trained epochs
    epoch_record_path = "epoch_count.pkl"

    # Loss function: mean squared error
    criterion = nn.MSELoss()

    # Save the model every this many epochs
    save_step = 100

    # ------------------------------------paths------------------------------------
    # Dataset location
    img_path = "train_images/"
    # Output folder for generated images
    output_path = "output_images/"

    # When resuming, read how many epochs have been trained so far
    if from_old_model:
        with open(epoch_record_path, "rb") as file:
            last_epoch_number = pickle.load(file)
    else:
        last_epoch_number = 0


# Fix all random number seeds
def seed_all(seed):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True


if config.use_seed:
    seed_all(seed=config.seed)

# -----------------------------------transforms------------------------------------
def get_transforms(img_size):
    # Resize, then normalize to [-1, 1]
    return Compose(
        [Resize(img_size, img_size),
         Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, p=1.0),
         ToTensorV2(p=1.0)]
    )


# Add random noise to an image before it is fed to a discriminator D
# Adapted from: https://blog.csdn.net/Mr_Lowbee/article/details/107990345
# Note: noise was not actually enabled during training
def add_noise(image, sigma=20):
    if config.add_noise:
        sigma = sigma / 255
        noise_img = image + sigma * torch.randn_like(image)
        # The images are normalized to [-1, 1], so clamp to that range
        noise_img = noise_img.clamp(-1, 1)
        return noise_img
    else:
        return image


# ------------------------------------dataset------------------------------------
# Dataset that reads images from disk
if config.read_from == "Disk":
    class image_dataset(Dataset):
        def __init__(self, file_list, img_path, transform):
            # files list
            self.file_list = file_list
            self.img_path = img_path
            self.transform = transform

        def __getitem__(self, index):
            image_path = self.img_path + self.file_list[index]
            img = cv2.imread(image_path)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = self.transform(image=img)['image']
            return img

        def __len__(self):
            return len(self.file_list)

# Dataset that keeps all images in memory
elif config.read_from == "Memory":
    class image_dataset(Dataset):
        def __init__(self, file_list, img_path, transform):
            self.imgs = []
            for file in file_list:
                image_path = img_path + file
                img = cv2.imread(image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                img = transform(image=img)['image']
                self.imgs.append(img)

        def __getitem__(self, index):
            return self.imgs[index]

        def __len__(self):
            return len(self.imgs)


# ------------------------------------main------------------------------------
def main():
    # Use the GPU if one is available, otherwise fall back to the CPU
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Use " + str(device))

    # Create the output folder
    if not os.path.exists(config.output_path):
        os.mkdir(config.output_path)

    # Create the dataset
    file_list = None
    for path, dirs, files in os.walk(config.img_path, topdown=False):
        file_list = list(files)

    train_dataset = image_dataset(file_list, config.img_path, transform=get_transforms(config.img_size))
    train_loader = DataLoader(dataset=train_dataset, batch_size=config.batchSize, shuffle=True)

    # Fetch the discriminator and generator models from model.py;
    # there are two discriminators: a global one and a patch (local) one
    G_model = get_G_model(config.from_old_model, device, config.G_model_path)
    D_model_global = get_D_model_global(config.from_old_model, device, config.D_model_global_path)
    D_model_patch = get_D_model_patch(config.from_old_model, device, config.D_model_patch_path)

    # Define the optimizers for G and D; AdamW is used here
    G_optimizer = AdamW(G_model.parameters(), lr=3e-4, weight_decay=1e-6)
    D_optimizer_global = AdamW(D_model_global.parameters(), lr=3e-4, weight_decay=1e-6)
    D_optimizer_patch = AdamW(D_model_patch.parameters(), lr=3e-4, weight_decay=1e-6)

    # When resuming a previous run, load the saved optimizer parameters
    if config.from_old_model:
        G_optimizer.load_state_dict(torch.load(config.G_optimizer_path))
        D_optimizer_global.load_state_dict(torch.load(config.D_optimizer_global_path))
        D_optimizer_patch.load_state_dict(torch.load(config.D_optimizer_patch_path))

    # Loss function
    criterion = config.criterion

    # Mixed-precision speed-up
    if config.use_apex:
        G_model, G_optimizer = amp.initialize(G_model, G_optimizer, opt_level="O1")
        D_model_global, D_optimizer_global = amp.initialize(D_model_global, D_optimizer_global, opt_level="O1")
        D_model_patch, D_optimizer_patch = amp.initialize(D_model_patch, D_optimizer_patch, opt_level="O1")

    # Record the training start time
    train_start = time.time()

    # Define the labels: a single value per image for the traditional
    # discriminator, one value per patch for the patch discriminator.
    # Real labels use label smoothing: all 0.9
    real_labels_global = Variable(torch.ones(config.batchSize, 1)-0.1).to(device)
    real_labels_patch = Variable(torch.ones(config.batchSize, config.patch_label_length)-0.1).to(device)

    # Fake labels are all 0; the smoothing is one-sided, so they are not smoothed
    fake_labels_global = Variable(torch.zeros(config.batchSize, 1)).to(device)
    fake_labels_patch = Variable(torch.zeros(config.batchSize, config.patch_label_length)).to(device)

    # Training epochs
    for epoch in range(config.epochs):
        print("start epoch "+str(epoch+1)+":")
        # Variables for tracking progress and losses
        batch_num = len(train_loader)
        D_loss_sum_global = 0
        D_loss_sum_patch = 0
        G_loss_sum = 0
        count = 0

        # Pull batches from the dataloader
        for index, images in enumerate(train_loader):
            count += 1
            # Move the images to the compute device
            images = images.to(device)

            # Track whether the real and fake labels have been swapped
            exchange_labels = False

            # With some probability, swap the labels while training the discriminators
            if random.uniform(0, 1) < config.D_train_label_exchange:
                real_labels_global, fake_labels_global = fake_labels_global, real_labels_global
                real_labels_patch, fake_labels_patch = fake_labels_patch, real_labels_patch
                exchange_labels = True

            # Train the global discriminator D_global
            D_optimizer_global.zero_grad()
            # Feed random seeds to the generator to produce fake images;
            # detach them so this update does not backpropagate into G
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds).detach()
            # Real samples through the discriminator
            real_output = D_model_global(add_noise(images))
            # For a final batch shorter than batchSize, truncate the real labels to match
            if len(real_labels_global) > len(real_output):
                D_loss_real = criterion(real_output, real_labels_global[:len(real_output)])
            else:
                D_loss_real = criterion(real_output, real_labels_global)
            # Fake samples through the discriminator
            fake_output = D_model_global(add_noise(fake_images))
            D_loss_fake = criterion(fake_output, fake_labels_global)
            # The discriminator loss is the sum of the real and fake losses
            D_loss_global = D_loss_real + D_loss_fake
            D_loss_sum_global += D_loss_global.item()
            # Reset the optimizer
            D_optimizer_global.zero_grad()
            # Update the discriminator with the loss
            if config.use_apex:
                with amp.scale_loss(D_loss_global, D_optimizer_global) as scaled_loss:
                    scaled_loss.backward()
            else:
                D_loss_global.backward()
            D_optimizer_global.step()

            # Train the patch discriminator D_patch
            D_optimizer_patch.zero_grad()
            # Feed random seeds to the generator to produce fake images;
            # detach them so this update does not backpropagate into G
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds).detach()
            # Real samples through the discriminator
            real_output = D_model_patch(add_noise(images))
            # For a final batch shorter than batchSize, truncate the real labels to match
            if len(real_labels_patch) > len(real_output):
                D_loss_real = criterion(real_output, real_labels_patch[:len(real_output)])
            else:
                D_loss_real = criterion(real_output, real_labels_patch)
            # Fake samples through the discriminator
            fake_output = D_model_patch(add_noise(fake_images))
            D_loss_fake = criterion(fake_output, fake_labels_patch)
            # The discriminator loss is the sum of the real and fake losses
            D_loss_patch = D_loss_real + D_loss_fake
            D_loss_sum_patch += D_loss_patch.item()
            # Reset the optimizer
            D_optimizer_patch.zero_grad()
            # Update the discriminator with the loss
            if config.use_apex:
                with amp.scale_loss(D_loss_patch, D_optimizer_patch) as scaled_loss:
                    scaled_loss.backward()
            else:
                D_loss_patch.backward()
            D_optimizer_patch.step()

            # If the labels were swapped earlier, swap them back now
            if exchange_labels:
                real_labels_global, fake_labels_global = fake_labels_global, real_labels_global
                real_labels_patch, fake_labels_patch = fake_labels_patch, real_labels_patch

            # Train the generator G
            # Feed random seeds to G to produce fake images
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds)
            # Run the fake images through both discriminators
            fake_output_global = D_model_global(add_noise(fake_images))
            fake_output_patch = D_model_patch(add_noise(fake_images))
            # Compare the discriminator outputs against the real labels to get the loss
            G_loss_global = criterion(fake_output_global, real_labels_global)
            G_loss_patch = criterion(fake_output_patch, real_labels_patch)
            G_loss = G_loss_global + G_loss_patch
            G_loss_sum += G_loss.item()
            # Reset the optimizer
            G_optimizer.zero_grad()
            # Update the generator G with the loss
            if config.use_apex:
                with amp.scale_loss(G_loss, G_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                G_loss.backward()
            G_optimizer.step()

            # Print progress
            if (index + 1) % 200 == 0:
                print("Epoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num))

        if (epoch+1) % config.save_step == 0:
            # Every N epochs, save the model parameters to disk
            torch.save(G_model.state_dict(), config.G_model_path)
            torch.save(D_model_global.state_dict(), config.D_model_global_path)
            torch.save(D_model_patch.state_dict(), config.D_model_patch_path)
            # Every N epochs, save the optimizer parameters to disk
            torch.save(G_optimizer.state_dict(), config.G_optimizer_path)
            torch.save(D_optimizer_global.state_dict(), config.D_optimizer_global_path)
            torch.save(D_optimizer_patch.state_dict(), config.D_optimizer_patch_path)
            # Save the cumulative epoch count
            with open(config.epoch_record_path, "wb") as file:
                pickle.dump(config.last_epoch_number + epoch + 1, file, 1)
            # Every N epochs, write a batch of generated images to the output folder
            img_seeds = torch.randn(config.batchSize, config.img_seed_dim).to(device)
            fake_images = G_model(img_seeds).detach()
            # Rescale the fake images from [-1, 1] back to [0, 1]
            fake_images = 0.5 * (fake_images + 1)
            fake_images = fake_images.clamp(0, 1)
            # Stack all generated images and write them out with torchvision's save_image()
            fake_images = fake_images.view(-1, 3, config.img_size, config.img_size)
            save_image(fake_images, config.output_path+str(config.last_epoch_number + epoch + 1)+'.png')

        # Print this epoch's losses and the elapsed time for reference
        print("D_loss_global:", round(D_loss_sum_global / count, 3))
        print("D_loss_patch:", round(D_loss_sum_patch / count, 3))
        print("G_loss:", round(G_loss_sum / count, 3))
        current_time = time.time()
        pass_time = int(current_time - train_start)
        time_string = str(pass_time // 3600) + " hours, " + str((pass_time % 3600) // 60) + " minutes, " + str(
            pass_time % 60) + " seconds."
        print("Time pass:", time_string)
        print()

    # Finished
    print("Done.")


if __name__ == '__main__':
    main()

3. Results

100 epochs:
[image]
1000 epochs:
[image]
5000 epochs:
[image]
10000 epochs:
[image]
15000 epochs:
[image]
20000 epochs:
[image]
25000 epochs:
[image]
Conclusion: the planets are not quite as smoothly round as those produced with the traditional discriminator alone, and small blemishes can still be spotted, but the overall outline is now close to a sphere, while the surface details have become noticeably sharper.

4. A Fun Aside

At epoch 8800, a very amusing result appeared in the bottom-right corner:
[image]
[image]
If I remember correctly, I asked you to learn to draw planets, not faces... (shudders)
And that hairstyle somehow looks awfully familiar... (hmm)

Next installment: Colorful Planet Image Generation 3: Code Improvements (PyTorch)
