彩色星球图片生成3:代码改进(pytorch版)


上一集: 彩色星球图片生成2:同时使用传统Gan判别器和马尔可夫判别器(pytorch版)

在上一集代码的基础上,进行了一些细节的修改以改进生成效果。

1. 修改

1.1 预处理缩放

用于预处理训练集图片的代码修改为:

import cv2
import os
from PIL import Image

# 数据集来源
img_path = "train_images/"

for path, dirs, files in os.walk(img_path, topdown=False):
    file_list = list(files)
for file in file_list:
    image_path = img_path + file
    img = cv2.imread(image_path, 1)
    # 裁剪为正方形
    bias = (img.shape[1] - img.shape[0]) // 2
    img = img[:, bias:bias+img.shape[0], :]
    (B, G, R) = cv2.split(img)
    # 颜色通道合并
    img = cv2.merge([R, G, B])
    # 使用Image的ANTIALIAS缩放算法
    img = Image.fromarray(img)
    img = img.resize((264, 264), Image.ANTIALIAS)
    img.save(image_path)

改进点:使用Image库的Image.ANTIALIAS参数进行图片缩放,预先缩放减少了图片在训练过程中缩放消耗的时间,同时Image库的高质量缩放算法能够在将大图像缩小到低分辨率时保留细节纹理,减少锯齿现象的出现,有利于训练模型学习细节纹理特征。

1.2 随机翻转

训练代码中dataset构建部分的代码修改为:

elif config.read_from == "Memory":
    class image_dataset(Dataset):
        def __init__(self, file_list, img_path, transform):
            self.imgs = []
            for file in file_list:
                image_path = img_path + file
                img = cv2.imread(image_path)
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                self.transform = transform
                # 将所有原始图片保存在内存中
                self.imgs.append(img)

        def __getitem__(self, index):
            # 将transform操作修改为每次提取具体图片时进行
            img = self.imgs[index]
            img = self.transform(image=img)['image']
            return img

        def __len__(self):
            return len(self.imgs)
def get_transforms(img_size):
    # 缩放分辨率并转换到0-1之间
    return Compose(
        # 取消了Resize的部分,同时添加了0.5概率随机垂直翻转与水平翻转
        # 星球图片显然理应可以随意翻转,有效扩大了训练集的信息量
         [ HorizontalFlip(p=0.5),
         VerticalFlip(p=0.5),
         Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), max_pixel_value=255.0, p=1.0),
         ToTensorV2(p=1.0)]
    )

修改点:将原始图片保存在内存中,在每个epoch中分别随机水平和垂直翻转,再转换为Tensor数据类型加入运算,相当于增加了训练集的数量。

1.3 修改全局判别器

将模型部分中全局判别器的代码修改如下:

# 全局判别器,传统gan
class D_net_global(nn.Module):
    def __init__(self):
        super(D_net_global,self).__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=5, stride=3, padding=1, bias=False),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(512, 16, kernel_size=4, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(16),
            nn.LeakyReLU(0.2, True),
        )
        self.classifier = nn.Sequential(
            # 将下方的两行注释消除
            # nn.Linear(1024, 1024),
            # nn.ReLU(True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        features = self.features(img)
        features = features.view(features.shape[0], -1)
        output = self.classifier(features)
        return output

修改点:清除了判别器末尾的全连接层,同时删除了与之匹配的relu层,最大程度保留了画面细节的信息。

1.4 修改进度打印

训练过程中的进度打印代码修改为:

# 打印程序工作进度
print("\rEpoch: %2d, Batch: %4d / %4d" % (epoch + 1, index + 1, batch_num), end="")

修改点:末尾换行符取消,每次打印时通过\r移动到行首,覆盖上一次print的内容,从而做到在同一行中实时刷新batch的工作进度,更加直观简洁。

2. 效果

改进后,训练了更多epoch,并且将训练集扩大到了128张,训练了约10000个epoch之后,生成的画面效果如下:
请添加图片描述
请添加图片描述
请添加图片描述
能够明显看到画面有了极大的提升,星球表面纹理清晰,网格现象大幅度减少,背景噪点减少。
最后放张训练过程的全家福【每张图片4x8,间隔为100epoch】:请添加图片描述

3. 总结

这次在画面细节和减少网格图上获得了较大的提升,同时生成器也学会了绘制银河图案,遗憾的是完全没有学会训练集中的行星星环,而是无视了所有的星环。
下一步改进计划考虑加入分类器等更为复杂的网络结构使图像生成更加多样化,仍然在试图实现中……

下一集:彩色星球图片生成4:转置卷积+插值缩放+卷积收缩(pytorch版)

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值