网络模型(VAE-变分自编码)

概念

无监督学习,用于压缩还原、在保留部分原始信息的基础上生成、换脸……

小明是个画家,有 A 和 B 两个人,看 A 两秒钟,看 B 一小时,然后画 A,画中会有 B 的特征。

编码:提取特征。(看 A)
解码:学习新特征,还原。(看 B 画 A)
VAE
编码结果向 N(0, 1) 靠近,因为真实分布和 N(0, 1) 是有差别的,所以解码生成的图片会模糊。

编码器得出的正态分布:N(μ, σ²)。
标准正态分布:N(0, 1)。
关系:N(μ, σ²) = N(0, 1) * σ² + μ。

编码损失:计算 N(μ, σ²) 和 N(0, 1) 差距,也就是 KL 散度。
KL(N(μ, σ²) || N(0, 1)) = (-logσ² + μ² + σ² - 1) / 2。

实验(手写数字图片压缩还原)

数据集:MNIST。

网络结构:

  • 编码:RNN + 标准化(BN)+ 激活(ReLU)。
  • 解码:转置卷积 + 标准化(BN)+ 激活(ReLU)。

优化器:Adam。

损失函数:均方差(MSELoss)。

输出:图片。

网络

import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            # (28 * 28)→(14 * 14)
            nn.Conv2d(1, 128, 3, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            # (14 * 14)→(7 * 7)
            nn.Conv2d(128, 256, 3, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
            # (7 * 7)→(3 * 3)
            nn.Conv2d(256, 512, 3, 2, 0), nn.BatchNorm2d(512), nn.ReLU(),
            # (3 * 3)→(1 * 1)
            nn.Conv2d(512, 2, 3, 1, 0)
        )

    def forward(self, x):
        # [n,2,1,1]
        out = self.conv(x)
        # 取 c 轴第一个作为 μ
        miu = out[:, :1, :, :]
        # 取 c 轴第二个作为 logσ²
        log_sigma = out[:, 1:, :, :]
        return miu, log_sigma


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_t = nn.Sequential(
            nn.ConvTranspose2d(128, 512, 3, 1, 0), nn.BatchNorm2d(512), nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 3, 2, 0), nn.BatchNorm2d(256), nn.ReLU(),
            # 尺寸除不尽时,要加 output_padding = 1
            nn.ConvTranspose2d(256, 128, 3, 2, 1, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 1, 3, 2, 1, 1)
        )

    def forward(self, miu, log_sigma, z):
        # σ² = e ** logσ²
        x = z * torch.exp(log_sigma) + miu
        # [n,h,w,c] → [n,c,h,w]
        x = x.permute(0, 3, 1, 2)
        return self.conv_t(x)


class MyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x, z):
        miu, log_sigma = self.encoder(x)
        out = self.decoder(miu, log_sigma, z)
        return miu, log_sigma, out

训练

import torch
from torch import nn
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.utils.data import DataLoader
import os

from net import MyNet


batch_size = 100
net_path = r"modules/mynet.pth"

train_flag = True

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])])
if train_flag:
    dataset = datasets.MNIST(r"data", train=True, transform=transform, download=True)
    dataloader = DataLoader(dataset, batch_size, shuffle=True)
else:
    dataset = datasets.MNIST(r"data", train=False, transform=transform, download=False)
    dataloader = DataLoader(dataset, batch_size, shuffle=False)


if __name__ == '__main__':
    # 加载网络
    if os.path.isfile(net_path):
        net = torch.load(net_path).to(device)
    else:
        net = MyNet().to(device)
    opt = torch.optim.Adam(net.parameters())
    # 只在 bitch 上做平均
    loss_fn = nn.MSELoss(reduction='sum')

    if train_flag:
        # 训练
        net.train()
        while True:
            for i, (x, y) in enumerate(dataloader):
                x = x.to(device)
                # 从标准正态分布中取值
                z = torch.randn(128).to(device)
                miu, log_sigma, out = net(x, z)

                # 编码损失(N(μ,σ²) 和 N(0,1))
                loss_e = torch.mean((-log_sigma + torch.exp(log_sigma) + miu ** 2 - 1) / 2)
                # 解码损失(生成的图和原图)
                loss_d = loss_fn(out, x)
                loss = loss_e + loss_d

                opt.zero_grad()
                loss.backward()
                opt.step()
                print("i:{},loss:{:.5}".format(i, loss))
            # 保存网络
            torch.save(net, net_path)
    else:
        # 测试
        net.eval()
        for i, (x, y) in enumerate(dataloader):
            x = x.to(device)
            # 从标准正态分布中取值
            z = torch.randn(128).to(device)
            miu, log_sigma, out = net(x, z)

            # 保存图片
            fake_img = out.data
            img = x.data
            save_image(fake_img, "./img/{}-fake_img.png".format(i), nrow=10)
            save_image(img, "./img/{}-real_img.png".format(i), nrow=10)
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值