【图像生成】(一) DNN 原理 & pytorch代码实例

1.简介

从之前开始就一直想做这个【图像生成】系列,主要是之前做过很多跟GAN相关的项目,同时AIGC和Stable Diffusion的火爆也开始让这个方向的研究活跃起来。

这个板块主要是想介绍【图像生成】模型的发展过程,以及其中最主要的几个代表模型,来研究和分析其中的改进点和创新性,来更好地帮助大家理解图像生成模型。

目前想到的有以下4个模型:

  • DNN
  • GAN
  • VAE
  • Diffusion

数据集我们就用最简单的MNIST手写数字数据集,来做基本的生成和condition生成任务,这样也方便更好理解模型的流程和思路,每个模型都会介绍它们的原理、公式以及对应的pytorch代码实现。如果对其他的任务比如cycleGAN(图生图)、Stable Diffusion(顶流)、ControlNet(类似结构)等感兴趣的,可以留言评论或者联系我,可以考虑后续再出相应的介绍。


2.原理

用什么模型来生成图像呢?首先想到的肯定是最质朴的神经网络DNN。当然这个方法在实际中不会用到,因为它存在很大的缺陷,具体什么缺陷后面就可以看出来。这里只是为了抛砖引玉。

现在梳理下整体的思路,我们需要输入一个标签(范围从0到9,表示我们想要生成的数字类别),然后需要输出一个对应标签的图像。自然而然就出现以下pipeline: 

其中labels为数字的对应标签,通过embedding先转换为特征向量(这里的embedding可以使用最简单的ont_hot),然后使用Linear对特征进行提取和通道转换,最后经过reshape转换为二维后再输入进ConvTrans2d,生成最后的二维图像。

为了对模型进行训练,我们使用对应的图像和preds输出基于MSE计算损失(最直观的损失计算方法,但存在弊端),同时为了提高生成图像的随机性,我们在输入端加入了随机latent,与labels的embedding向量进行拼接。因此,最终的模型pipeline如下:


3.代码

OK,接下来我们就用pytorch来实现以上pipeline。

3.1模型

首先我们实现DNN网络,变量概念和基本流程都已经在代码中给出。需要注意的是网络的输入应该是latent和labels embedding向量拼接后的维度,输出维度是1(单通道灰度图像),网络最后需要加上Sigmoid来使输出范围在0-1之间。同时Linear的输出和ConvTrans2d的输入需要设计和匹配来实现输出的图像大小与原图大小相同。

class DNN(nn.Module):
    def __init__(self, input_dim=100, output_dim=1, class_num=10):
        '''
        初始化网络
        :param input_dim:输入维度,也是latent维度
        :param output_dim:输出维度,表示最终生成图片的通道数
        :param class_num:图像种类,代表condition种类
        '''
        super(DNN, self).__init__()
        # 网络的输入是latent的维度拼接上condition向量的维度
        self.input_dim = input_dim + class_num
        self.output_dim = output_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x

3.2数据集

数据集直接从torchvision调用现成的MNIST数据集函数,并且通过dataloader进行包装。

    def init_dataloader(self):
        '''
        初始化数据集和dataloader
        '''
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_dataset = MNIST('./data/',
                              train=True,
                              download=True,
                              transform=tf)
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        val_dataset = MNIST('./data/',
                            train=False,
                            download=True,
                            transform=tf)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        self.output_dim = self.train_dataloader.__iter__().__next__()[0].shape[1]

3.3训练

训练pipeline也是一个正常流程,将labels进行one_hot编码后与latent进行拼接输入网络,然后再将网络输出与原图计算损失。

    def train(self):
        self.model.train()
        print('训练开始!!')
        for epoch in range(self.epoch):
            self.model.train()
            loss_mean = 0
            for i, (images, labels) in enumerate(self.train_dataloader):
                # 生成对应batch和维度的latent
                z = torch.rand((self.batch_size, self.z_dim)).to(self.device)
                images, labels = images.to(self.device), labels.to(self.device)
                # 将原始label做one hot后作为condition向量
                labels = F.one_hot(labels, num_classes=10)
                self.optimizer.zero_grad()
                # 将latent和condition拼接后输入网络
                generated_images = self.model(torch.cat((z, labels), dim=1))
                loss = self.loss(generated_images, images)
                loss_mean += loss.item()
                loss.backward()
                self.optimizer.step()
            train_loss = loss_mean / len(self.train_dataloader)
            val_loss = self.evaluation()
            print('epoch:{}, training loss:{:.4f}, validation loss:{:.4f}'.format(epoch, train_loss, val_loss))
            self.visualize_results(epoch)

3.4推理&可视化

最后我们使用训练好的模型进行推理和可视化。我们随机生成100个sample,然后10个类别每个类别分别占10个sample。最后生成的结果如下:

@torch.no_grad()
    def visualize_results(self, epoch):
        self.model.eval()
        # 保存结果路径
        output_path = 'results/DNN'
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        tot_num_samples = self.sample_num
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        z = torch.rand((tot_num_samples, self.z_dim)).to(self.device)
        # 生成对应sample个condition向量,每十个sample为一类
        labels = F.one_hot(torch.Tensor(np.repeat(np.arange(10), 10)).to(torch.int64), num_classes=10).to(self.device)
        generated_images = self.model(torch.cat((z, labels), dim=1))
        save_image(generated_images, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)

这样的话每个类别都会生成对应的图片,但是有没有发现一个问题?就是每个生成的图片长得太像了。

我个人的理解而言,由于MSE损失计算的是基于像素间的差别,所以生成的图像只会与大部分的典型的图像相似,就算加入了随机latent,在模型不断的收敛过程中latent部分的输出会尽可能接近0,来保证输出结果对典型部分数据的相似性。所以说才会造成不同随机latent的生成图像都很相似的情况。

完整代码如下:

import torch, time, os
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image
import torch.nn.functional as F


class DNN(nn.Module):
    def __init__(self, input_dim=100, output_dim=1, class_num=10):
        '''
        初始化网络
        :param input_dim:输入维度,也是latent维度
        :param output_dim:输出维度,表示最终生成图片的通道数
        :param class_num:图像种类,代表condition种类
        '''
        super(DNN, self).__init__()
        # 网络的输入是latent的维度拼接上condition向量的维度
        self.input_dim = input_dim + class_num
        self.output_dim = output_dim

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * 7 * 7),
            nn.BatchNorm1d(128 * 7 * 7),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, 7, 7)
        x = self.deconv(x)
        return x


class ImageGenerator(object):
    def __init__(self):
        '''
        初始化,定义超参数、数据集、网络结构等
        '''
        self.epoch = 5
        self.sample_num = 100
        self.batch_size = 64
        self.z_dim = 62
        self.lr = 0.0001
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.init_dataloader()
        self.model = DNN(input_dim=self.z_dim, output_dim=self.output_dim).to(self.device)
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.lr)
        self.loss = nn.MSELoss().to(self.device)

    def init_dataloader(self):
        '''
        初始化数据集和dataloader
        '''
        tf = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,))
        ])
        train_dataset = MNIST('./data/',
                              train=True,
                              download=True,
                              transform=tf)
        self.train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        val_dataset = MNIST('./data/',
                            train=False,
                            download=True,
                            transform=tf)
        self.val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, shuffle=False)
        self.output_dim = self.train_dataloader.__iter__().__next__()[0].shape[1]

    def train(self):
        self.model.train()
        print('训练开始!!')
        for epoch in range(self.epoch):
            self.model.train()
            loss_mean = 0
            for i, (images, labels) in enumerate(self.train_dataloader):
                # 生成对应batch和维度的latent
                z = torch.rand((self.batch_size, self.z_dim)).to(self.device)
                images, labels = images.to(self.device), labels.to(self.device)
                # 将原始label做one hot后作为condition向量
                labels = F.one_hot(labels, num_classes=10)
                self.optimizer.zero_grad()
                # 将latent和condition拼接后输入网络
                generated_images = self.model(torch.cat((z, labels), dim=1))
                loss = self.loss(generated_images, images)
                loss_mean += loss.item()
                loss.backward()
                self.optimizer.step()
            train_loss = loss_mean / len(self.train_dataloader)
            val_loss = self.evaluation()
            print('epoch:{}, training loss:{:.4f}, validation loss:{:.4f}'.format(epoch, train_loss, val_loss))
            self.visualize_results(epoch)

    @torch.no_grad()
    def evaluation(self):
        self.model.eval()
        loss_mean = 0
        for i, (images, labels) in enumerate(self.val_dataloader):
            # 生成对应image batch和维度的latent
            z = torch.rand((images.shape[0], self.z_dim)).to(self.device)
            images, labels = images.to(self.device), labels.to(self.device)
            # 将原始label做one hot后作为condition向量
            labels = F.one_hot(labels, num_classes=10)
            # 将latent和condition拼接后输入网络
            generated_images = self.model(torch.cat((z, labels), dim=1))
            loss = self.loss(generated_images, images)
            loss_mean += loss.item()
        return loss_mean / len(self.val_dataloader)

    @torch.no_grad()
    def visualize_results(self, epoch):
        self.model.eval()
        # 保存结果路径
        output_path = 'results/DNN'
        if not os.path.exists(output_path):
            os.makedirs(output_path)

        tot_num_samples = self.sample_num
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        z = torch.rand((tot_num_samples, self.z_dim)).to(self.device)
        # 生成对应sample个condition向量,每十个sample为一类
        labels = F.one_hot(torch.Tensor(np.repeat(np.arange(10), 10)).to(torch.int64), num_classes=10).to(self.device)
        generated_images = self.model(torch.cat((z, labels), dim=1))
        save_image(generated_images, os.path.join(output_path, '{}.jpg'.format(epoch)), nrow=image_frame_dim)


if __name__ == '__main__':
    generator = ImageGenerator()
    generator.train()

4.总结

OK我们知道了朴素DNN在【图像生成】这块并不好用,下一篇文章里我会介绍GAN是如何改进这个问题。


业务合作/学习交流+v:lizhiTechnology

 如果想要了解更多图像生成相关知识,可以参考我的专栏和其他相关文章:

图像生成_Lcm_Tech的博客-CSDN博客

【图像生成】(一) DNN 原理 & pytorch代码实例_pytorch dnn代码-CSDN博客

【图像生成】(二) GAN 原理 & pytorch代码实例_gan代码-CSDN博客

【图像生成】(三) VAE原理 & pytorch代码实例_vae算法 是如何生成图的-CSDN博客

【图像生成】(四) Diffusion原理 & pytorch代码实例_diffusion unet-CSDN博客

如果想要了解更多深度学习相关知识,可以参考我的其他文章:

深度学习_Lcm_Tech的博客-CSDN博客

【优化器】(一) SGD原理 & pytorch代码解析_sgd优化器-CSDN博客

【损失函数】(一) L1Loss原理 & pytorch代码解析_l1 loss-CSDN博客

【diffusers】(一) diffusers库介绍 & 框架代码解析-CSDN博客

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值