Glow-pytorch: Walking Through a GitHub Reimplementation

After studying the Glow paper, I worked through a reimplementation on GitHub to better understand the details of how the algorithm is implemented. The chosen repository is a PyTorch implementation: https://github.com/rosinality/glow-pytorch. Glow is a flow-based model; its structure is straightforward and its math is exact, but every module must be defined so that the model can also be run in reverse.

The code of this repository lives mainly in model.py and train.py. The structure and logic are clear, and the implementation follows the paper closely, so you can read the paper side by side with the code. Interested readers can also consult a reading note on the Glow paper; understanding Table 1 there is essentially enough to follow the whole code. Below, these two Python files are annotated to aid study and understanding.
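
As a quick reminder of the math the code implements (standard normalizing-flow material, not specific to this repo): for an invertible transform z = f(x) built as a composition of layers h_0 = x, h_1, ..., h_K = z, maximum-likelihood training uses the change-of-variables formula log p(x) = log p(z) + sum_i log |det(dh_i / dh_{i-1})|. This is why every forward method below returns both the transformed tensor and a logdet contribution, while every reverse method inverts the transform exactly.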

model.py

Cleanly defines, one by one, the Glow modules shown in Figure 2 of the paper.

import torch
from torch import nn
from torch.nn import functional as F
from math import log, pi, exp
import numpy as np
from scipy import linalg as la

device = "gpu:0" if torch.cuda.is_available() else "cpu"

logabs = lambda x: torch.log(torch.abs(x))  # small helper lambda computing log(|x|)


class ActNorm(nn.Module):
    def __init__(self, in_channel, logdet=True):
        super().__init__()
        # Statistics are computed per channel only, so loc and scale each hold just num_channels values
        self.loc = nn.Parameter(torch.zeros(1, in_channel, 1, 1))  # the shift b from Table 1 of the paper
        self.scale = nn.Parameter(torch.ones(1, in_channel, 1, 1))  # the scale s from Table 1 of the paper

        # self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
        # Flags whether the first batch has been seen; if not, initialize() computes that batch's
        # per-channel mean and std and uses them to initialize loc and scale
        self.initialized = nn.Parameter(torch.tensor(0, dtype=torch.uint8), requires_grad=False)

        self.logdet = logdet

    def initialize(self, input):
        with torch.no_grad():
            # The original repo computes the per-channel mean/std by permuting and flattening
            # the input to shape [C, N*H*W]; torch.mean/std over dim=(0, 2, 3) is equivalent
            mean = torch.mean(input, dim=(0, 2, 3), keepdim=True)
            std = torch.std(input, dim=(0, 2, 3), keepdim=True)

            # The data-dependent initialization from the paper: the first batch leaves actnorm
            # with zero mean and unit variance per channel, stabilizing early training
            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input):  # forward pass
        _, _, height, width = input.shape  # input shape is [batch_size, num_channel, height, width]

        if self.initialized.item() == 0:  # 0 on the first batch, so run the data-dependent init; then fill with 1 so it never runs again
            self.initialize(input)
            self.initialized.fill_(1)

        log_abs = logabs(self.scale)  # elementwise log|scale|

        logdet = height * width * torch.sum(log_abs)  # change in log-likelihood, a scalar

        if self.logdet:  # whether to also return the log-determinant term
            return self.scale * (input + self.loc), logdet

        else:
            return self.scale * (input + self.loc)

    def reverse(self, output):  # inverse pass, used at inference time
        return output / self.scale - self.loc
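
# --- Sanity check (my addition, not part of the repo): reverse() should undo
# forward(), and the data-dependent init should normalize the first batch ---
def _actnorm_sanity_check():
    x = torch.randn(8, 4, 16, 16) * 3 + 1  # deliberately non-normalized input
    actnorm = ActNorm(4)
    y, logdet = actnorm(x)  # the first call triggers the data-dependent init
    assert torch.allclose(y.mean(dim=(0, 2, 3)), torch.zeros(4), atol=1e-4)  # ~zero mean per channel
    assert torch.allclose(actnorm.reverse(y), x, atol=1e-4)  # exact inverse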


# Invertible 1x1 2D convolution
class InvConv2d(nn.Module):
    def __init__(self, in_channel):
        super().__init__()

        weight = torch.randn(in_channel, in_channel)  # 1x1 conv weight of shape [c, c]; effectively an MLP across channels
        q, _ = torch.qr(weight)  # QR decomposition (torch.linalg.qr in newer PyTorch); q is orthogonal, so det(q) = ±1 ≠ 0
        # Expand to [c, c, 1, 1]: conv2d expects weights of shape [out_channels, in_channels, kH, kW]
        weight = q.unsqueeze(2).unsqueeze(3)
        self.weight = nn.Parameter(weight)

    def forward(self, input):  # forward pass
        _, _, height, width = input.shape

        out = F.conv2d(input, self.weight)
        logdet = height * width * torch.slogdet(self.weight.squeeze().double())[1].float()  # slogdet in double precision for stability

        return out, logdet

    def reverse(self, output):  # inverse pass
        return F.conv2d(output, self.weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
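
# --- Sanity check (my addition): a 1x1 convolution is just a matrix multiply
# across channels, so each output pixel equals W @ (input pixel), and
# reverse() applies W^{-1} ---
def _invconv_sanity_check():
    x = torch.randn(2, 8, 4, 4)
    conv = InvConv2d(8)
    y, logdet = conv(x)
    w = conv.weight.squeeze()  # [8, 8]
    assert torch.allclose(y[0, :, 1, 2], w @ x[0, :, 1, 2], atol=1e-5)
    assert torch.allclose(conv.reverse(y), x, atol=1e-4)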


# Faster invertible 1x1 convolution using an LU decomposition
class InvConv2dLU(nn.Module):
    def __init__(self, in_channel):
        super().__init__()

        weight = np.random.randn(in_channel, in_channel)
        q, _ = la.qr(weight)  # any matrix admits a QR decomposition; q is orthogonal

        # A matrix with nonzero determinant always has an LU decomposition; q is orthogonal
        # with determinant +1 or -1, so it qualifies
        # A = P L U: P is a permutation matrix, L is lower triangular with unit diagonal, U is upper triangular
        w_p, w_l, w_u = la.lu(q.astype(np.float32))
        w_s = np.diag(w_u)  # the diagonal of U as a 1-D array
        w_u = np.triu(w_u, 1)  # keep only the strictly upper triangle, i.e. zero out the diagonal
        u_mask = np.triu(np.ones_like(w_u), 1)  # ones on the strictly upper triangle, zeros elsewhere
        l_mask = u_mask.T  # ones on the strictly lower triangle, zeros elsewhere
        # Convert the arrays above to tensors
        w_p = torch.from_numpy(w_p).to(device)
        w_l = torch.from_numpy(w_l).to(device)
        w_s = torch.from_numpy(w_s).to(device)
        w_u = torch.from_numpy(w_u).to(device)

        # Quantities that are not trained are registered as buffers
        self.register_buffer("w_p", w_p)
        self.register_buffer("u_mask", torch.from_numpy(u_mask))
        self.register_buffer("l_mask", torch.from_numpy(l_mask))
        self.register_buffer("s_sign", torch.sign(w_s))  # 记录w_s的符号
        self.register_buffer("l_eye", torch.eye(l_mask.shape[0]))
        # Trainable quantities
        self.w_l = nn.Parameter(w_l)
        self.w_s = nn.Parameter(logabs(w_s))
        self.w_u = nn.Parameter(w_u)

    def forward(self, input):  # forward pass
        _, _, height, width = input.shape

        weight = self.calc_weight()  # reassemble the weight from its decomposed factors

        out = F.conv2d(input, weight)  # the 1x1 convolution
        logdet = height * width * torch.sum(self.w_s)  # log|det W| = sum(log|s|); no determinant needed at train time

        return out, logdet

    def calc_weight(self):
        weight = (
                self.w_p
                @ (self.w_l * self.l_mask + self.l_eye)  # masked L plus the identity restores the unit diagonal
                @ ((self.w_u * self.u_mask) + torch.diag(self.s_sign * torch.exp(self.w_s)))  # w_s stores log|s|, so recover s as sign * exp
        )

        return weight.unsqueeze(2).unsqueeze(3)

    def reverse(self, output):  # inverse pass
        weight = self.calc_weight()

        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
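
# --- Sanity check (my addition): the weight reassembled from P, L, U has
# log|det W| = sum(w_s), which is exactly what forward() uses as the logdet
# without ever computing a determinant ---
def _invconv_lu_sanity_check():
    conv = InvConv2dLU(8)
    w = conv.calc_weight().squeeze()
    assert torch.allclose(torch.slogdet(w.double())[1].float(), torch.sum(conv.w_s), atol=1e-4)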


# Zero-initialized 2D convolution
class ZeroConv2d(nn.Module):
    def __init__(self, in_channel, out_channel, padding=1):
        super().__init__()

        self.conv = nn.Conv2d(in_channel, out_channel, 3, padding=0)  # padding is applied manually in forward()
        self.conv.weight.data.zero_()  # weights initialized to zero
        self.conv.bias.data.zero_()  # bias initialized to zero
        # Shape [1, out_channel, 1, 1]: a learnable per-channel scale, also zero-initialized
        self.scale = nn.Parameter(torch.zeros(1, out_channel, 1, 1))

    def forward(self, input):
        out = F.pad(input, [1, 1, 1, 1], value=1)  # pad the borders with ones rather than zeros
        out = self.conv(out)
        out = out * torch.exp(self.scale * 3)  # per-channel rescaling; exp(0) = 1 at initialization

        return out
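
# --- Sanity check (my addition): with weight, bias and scale all zero, the
# layer outputs exactly zero, so whatever it feeds starts as an identity ---
def _zeroconv_sanity_check():
    conv = ZeroConv2d(4, 8)
    x = torch.randn(2, 4, 16, 16)
    assert torch.equal(conv(x), torch.zeros(2, 8, 16, 16))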


# Affine coupling layer
class AffineCoupling(nn.Module):
    def __init__(self, in_channel, filter_size=512, affine=True):
        super().__init__()

        self.affine = affine

        # The nonlinear transform NN() from Table 1 of the paper: just a small conv net
        self.net = nn.Sequential(
            nn.Conv2d(in_channel // 2, filter_size, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(filter_size, filter_size, 1),
            nn.ReLU(inplace=True),
            # Zero-initialized, so the layer starts out as an identity transform (as in the paper);
            # with affine=True the output has in_channel channels so it can be split into log_s and t
            ZeroConv2d(filter_size, in_channel if self.affine else in_channel // 2),
        )
        # net[0] is the first conv layer
        self.net[0].weight.data.normal_(0, 0.05)
        self.net[0].bias.data.zero_()
        # net[2] is the second conv layer
        self.net[2].weight.data.normal_(0, 0.05)
        self.net[2].bias.data.zero_()

    def forward(self, input):  # forward pass
        in_a, in_b = input.chunk(2, 1)  # split the input in half along the channel dimension: x_a, x_b in Table 1

        if self.affine:  # affine coupling
            log_s, t = self.net(in_a).chunk(2, 1)  # the network outputs log_s and t
            # s = torch.exp(log_s)
            s = F.sigmoid(log_s + 2)  # Table 1 takes an exponential, but many implementations find sigmoid(log_s + 2) trains more stably
            # out_a = s * in_a + t
            out_b = (in_b + t) * s

            # s has shape [batch_size, in_channel // 2, h, w]; since it depends on the input,
            # logdet is a vector of length batch_size rather than a scalar
            logdet = torch.sum(torch.log(s).view(input.shape[0], -1), 1)
        else:
            net_out = self.net(in_a)
            out_b = in_b + net_out
            logdet = None  # additive coupling contributes no log-determinant

        return torch.cat([in_a, out_b], 1), logdet  # concatenate in_a and out_b

    def reverse(self, output):  # inverse pass
        out_a, out_b = output.chunk(2, 1)  # split along the channel dimension

        if self.affine:
            log_s, t = self.net(out_a).chunk(2, 1)
            # s = torch.exp(log_s)
            s = F.sigmoid(log_s + 2)
            # in_a = (out_a - t) / s
            in_b = out_b / s - t
        else:
            net_out = self.net(out_a)
            in_b = out_b - net_out

        return torch.cat([out_a, in_b], 1)
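
# --- Sanity check (my addition): at initialization the coupling is identity
# up to the constant factor sigmoid(2) on the second half, reverse() inverts
# forward() exactly, and logdet is per-sample rather than a scalar ---
def _coupling_sanity_check():
    x = torch.randn(2, 8, 16, 16)
    coupling = AffineCoupling(8)
    y, logdet = coupling(x)
    assert logdet.shape == (2,)
    assert torch.allclose(coupling.reverse(y), x, atol=1e-5)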


class Flow(nn.Module):
    def __init__(self, in_channel, affine=True, conv_lu=True):
        super().__init__()

        self.actnorm = ActNorm(in_channel)

        if conv_lu:
            self.invconv = InvConv2dLU(in_channel)

        else:
            self.invconv = InvConv2d(in_channel)

        self.coupling = AffineCoupling(in_channel, affine=affine)

    def forward(self, input):  # forward pass
        out, logdet = self.actnorm(input)
        out, det1 = self.invconv(out)
        out, det2 = self.coupling(out)

        logdet = logdet + det1
        if det2 is not None:
            logdet = logdet + det2

        return out, logdet

    def reverse(self, output):  # inverse pass: coupling first, then the 1x1 conv, then actnorm
        input = self.coupling.reverse(output)
        input = self.invconv.reverse(input)
        input = self.actnorm.reverse(input)

        return input
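
# --- Sanity check (my addition): one Flow step chains actnorm -> invertible
# 1x1 conv -> coupling; reverse() walks back in the opposite order ---
def _flow_sanity_check():
    flow = Flow(8).to(device)
    x = torch.randn(2, 8, 16, 16).to(device)
    y, logdet = flow(x)
    assert torch.allclose(flow.reverse(y), x, atol=1e-4)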


# Log of the Gaussian probability density
def gaussian_log_p(x, mean, log_sd):
    return -0.5 * log(2 * pi) - log_sd - 0.5 * (x - mean) ** 2 / torch.exp(2 * log_sd)


def gaussian_sample(eps, mean, log_sd):
    return mean + torch.exp(log_sd) * eps
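
# --- Sanity check (my addition): gaussian_log_p matches the log-density of
# torch.distributions.Normal with std = exp(log_sd) ---
def _gaussian_sanity_check():
    x, mean, log_sd = torch.randn(3), torch.randn(3), torch.randn(3)
    ref = torch.distributions.Normal(mean, torch.exp(log_sd)).log_prob(x)
    assert torch.allclose(gaussian_log_p(x, mean, log_sd), ref, atol=1e-5)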


class Block(nn.Module):
    '''Each block contains several flows'''
    def __init__(self, in_channel, n_flow, split=True, affine=True, conv_lu=True):
        super().__init__()

        squeeze_dim = in_channel * 4

        self.flows = nn.ModuleList()
        for i in range(n_flow):
            self.flows.append(Flow(squeeze_dim, affine=affine, conv_lu=conv_lu))

        self.split = split

        if split:  # the multi-scale structure from the RealNVP paper: deeper levels work on smaller tensors, so compute shrinks with depth
            self.prior = ZeroConv2d(in_channel * 2, in_channel * 4)

        else:
            self.prior = ZeroConv2d(in_channel * 4, in_channel * 8)

    def forward(self, input):
        b_size, n_channel, height, width = input.shape
        # The squeeze operation from Figure 2b of the paper: channels x4, height and width halved
        squeezed = input.view(b_size, n_channel, height // 2, 2, width // 2, 2)
        squeezed = squeezed.permute(0, 1, 3, 5, 2, 4)
        out = squeezed.contiguous().view(b_size, n_channel * 4, height // 2, width // 2)

        logdet = 0

        for flow in self.flows:  # pass through the flows; input and output shapes match
            out, det = flow(out)
            logdet = logdet + det

        if self.split:  # multi-scale: half of the output becomes a final latent, the other half goes on to the next transform
            out, z_new = out.chunk(2, 1)  # z_new is this block's half of the final z
            mean, log_sd = self.prior(out).chunk(2, 1)
            log_p = gaussian_log_p(z_new, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
        else:  # the whole output goes on to the next transform
            zero = torch.zeros_like(out)
            mean, log_sd = self.prior(zero).chunk(2, 1)
            log_p = gaussian_log_p(out, mean, log_sd)
            log_p = log_p.view(b_size, -1).sum(1)
            z_new = out

        return out, logdet, log_p, z_new

    def reverse(self, output, eps=None, reconstruct=False):
        input = output

        if reconstruct:
            if self.split:
                input = torch.cat([output, eps], 1)

            else:
                input = eps

        else:
            if self.split:
                mean, log_sd = self.prior(input).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = torch.cat([output, z], 1)

            else:
                zero = torch.zeros_like(input)
                # zero = F.pad(zero, [1, 1, 1, 1], value=1)
                mean, log_sd = self.prior(zero).chunk(2, 1)
                z = gaussian_sample(eps, mean, log_sd)
                input = z

        for flow in self.flows[::-1]:
            input = flow.reverse(input)

        b_size, n_channel, height, width = input.shape

        unsqueezed = input.view(b_size, n_channel // 4, 2, 2, height, width)
        unsqueezed = unsqueezed.permute(0, 1, 4, 2, 5, 3)
        unsqueezed = unsqueezed.contiguous().view(b_size, n_channel // 4, height * 2, width * 2)

        return unsqueezed
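
# --- Sanity check (my addition): the squeeze in forward() and the unsqueeze
# in reverse() are exact inverses: [B, C, H, W] -> [B, 4C, H/2, W/2] -> back ---
def _squeeze_sanity_check():
    x = torch.randn(2, 3, 8, 8)
    s = x.view(2, 3, 4, 2, 4, 2).permute(0, 1, 3, 5, 2, 4).contiguous().view(2, 12, 4, 4)
    u = s.view(2, 3, 2, 2, 4, 4).permute(0, 1, 4, 2, 5, 3).contiguous().view(2, 3, 8, 8)
    assert torch.equal(u, x)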


class Glow(nn.Module):
    def __init__(self, in_channel, n_flow, n_block, affine=True, conv_lu=True):
        super().__init__()

        self.blocks = nn.ModuleList()
        n_channel = in_channel
        for i in range(n_block - 1):
            self.blocks.append(Block(n_channel, n_flow, affine=affine, conv_lu=conv_lu))
            n_channel *= 2
        self.blocks.append(Block(n_channel, n_flow, split=False, affine=affine))  # the last block does not split

    def forward(self, input):
        log_p_sum = 0
        logdet = 0
        out = input
        z_outs = []

        for block in self.blocks:
            out, det, log_p, z_new = block(out)
            z_outs.append(z_new)  # collect the z produced by each block
            logdet = logdet + det  # accumulate the log-determinant

            if log_p is not None:
                log_p_sum = log_p_sum + log_p  # accumulate the prior log-likelihood

        return log_p_sum, logdet, z_outs

    def reverse(self, z_list, reconstruct=False):
        for i, block in enumerate(self.blocks[::-1]):
            if i == 0:  # the deepest block: its z serves as both the output and the eps input
                input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)

            else:   # every earlier block takes the previous block's output as its input
                input = block.reverse(input, z_list[-(i + 1)], reconstruct=reconstruct)

        return input
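
# --- Sanity check (my addition): encode a batch to its latents, then decode
# with reconstruct=True; the round trip reproduces the input up to the float
# error introduced by the matrix inverses ---
def _glow_sanity_check():
    glow = Glow(3, n_flow=4, n_block=2).to(device)
    x = (torch.rand(2, 3, 16, 16) - 0.5).to(device)
    log_p, logdet, z_outs = glow(x)
    x_rec = glow.reverse(z_outs, reconstruct=True)
    print((x_rec - x).abs().max())  # should be close to 0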

train.py

Implements the training loop. Note that the model input is not just the image: random noise is added to it, and the loss computation must account for this. The overall procedure: first train the model in the forward direction on images; once trained, feed randomly sampled z backwards through the model to generate images (i.e., call the model's reverse function).
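
For context (this is the standard dequantization argument, not something spelled out in the repo): pixel values are quantized to n_bins = 2^n_bits discrete levels, and uniform noise in [0, 1/n_bins) is added so the model fits a continuous density. The resulting continuous log-likelihood lower-bounds the discrete one and contributes the constant -log(n_bins) * n_pixel term that calc_loss adds below.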

from tqdm import tqdm
import numpy as np
from PIL import Image
from math import log, sqrt, pi

import argparse

import torch
from torch import nn, optim
from torch.autograd import Variable, grad
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils

from model import Glow

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

parser = argparse.ArgumentParser(description="Glow trainer")
parser.add_argument("--batch", default=16, type=int, help="batch size")
parser.add_argument("--iter", default=200000, type=int, help="maximum iterations")  # 迭代周期
parser.add_argument("--n_flow", default=32, type=int, help="number of flows in each block")  # block中flow的个数
parser.add_argument("--n_block", default=4, type=int, help="number of blocks")  # Glow中flow的个数
parser.add_argument(
    "--no_lu",
    action="store_true",
    help="use plain convolution instead of LU decomposed version",)  # 是否使用LU分解计算仿射耦合层的对数似然增量
parser.add_argument("--affine", action="store_true", help="use affine coupling instead of additive")  # 是否进行仿射
parser.add_argument("--n_bits", default=5, type=int, help="number of bits")
parser.add_argument("--lr", default=1e-4, type=float, help="learning rate")
parser.add_argument("--img_size", default=64, type=int, help="image size")
parser.add_argument("--temp", default=0.7, type=float, help="temperature of sampling")
parser.add_argument("--n_sample", default=20, type=int, help="number of samples")
parser.add_argument("path", metavar="PATH", type=str, help="Path to image directory")


def sample_data(path, batch_size, image_size):
    # Define the image transforms
    transform = transforms.Compose(
        [
            transforms.Resize(image_size),
            transforms.CenterCrop(image_size),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    # Build the dataset
    dataset = datasets.ImageFolder(path, transform=transform)
    loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
    loader = iter(loader)

    while True:
        try:
            yield next(loader)

        except StopIteration:
            loader = DataLoader(dataset, shuffle=True, batch_size=batch_size, num_workers=4)
            loader = iter(loader)
            yield next(loader)


# Compute the shape of z at each block
def calc_z_shapes(n_channel, input_size, n_flow, n_block):
    z_shapes = []

    for i in range(n_block - 1):
        input_size //= 2
        n_channel *= 2

        z_shapes.append((n_channel, input_size, input_size))

    input_size //= 2
    z_shapes.append((n_channel * 4, input_size, input_size))

    return z_shapes
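
# --- Worked example (my addition): with the default config (3-channel 64x64
# images, n_block=4), calc_z_shapes(3, 64, 32, 4) returns
#   [(6, 32, 32), (12, 16, 16), (24, 8, 8), (96, 4, 4)]
# whose total size 6*32*32 + 12*16*16 + 24*8*8 + 96*4*4 = 12288 = 3*64*64
# matches the input dimensionality, as invertibility requires ---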


def calc_loss(log_p, logdet, image_size, n_bins):
    # log_p = calc_log_p([z_list])
    n_pixel = image_size * image_size * 3

    loss = -log(n_bins) * n_pixel  # constant term from the uniform dequantization noise
    loss = loss + logdet + log_p

    return (
        (-loss / (log(2) * n_pixel)).mean(),  # NLL in bits per dimension; dividing by log(2) changes the base from e to 2
        (log_p / (log(2) * n_pixel)).mean(),
        (logdet / (log(2) * n_pixel)).mean(),
    )
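
# --- Worked example (my addition): for 64x64 RGB images with n_bits=5
# (n_bins=32), n_pixel = 64*64*3 = 12288 and the dequantization constant is
# -log(32) * 12288 ≈ -42590 nats. Dividing by log(2) * n_pixel converts the
# total from nats per image into the bits-per-dimension metric of the paper ---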


def train(args, model, optimizer):
    dataset = iter(sample_data(args.path, args.batch, args.img_size))  # image data loader
    n_bins = 2.0 ** args.n_bits  # number of quantization bins

    z_sample = []
    z_shapes = calc_z_shapes(3, args.img_size, args.n_flow, args.n_block)  # shapes of the per-block z's
    for z in z_shapes:
        z_new = torch.randn(args.n_sample, *z) * args.temp
        z_sample.append(z_new.to(device))

    with tqdm(range(args.iter)) as pbar:  # main training loop
        for i in pbar:
            image, _ = next(dataset)
            image = image.to(device)

            # Quantize the image
            image = image * 255
            if args.n_bits < 8:
                image = torch.floor(image / 2 ** (8 - args.n_bits))
            image = image / n_bins - 0.5  # map values into [-0.5, 0.5]

            # The model input is the image plus uniform dequantization noise
            if i == 0:  # the first step only triggers the data-dependent initialization, so no parameter update happens
                with torch.no_grad():
                    log_p, logdet, _ = model.module(
                        image + torch.rand_like(image) / n_bins
                    )

                    continue
            else:
                log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins)

            logdet = logdet.mean()

            loss, log_p, log_det = calc_loss(log_p, logdet, args.img_size, n_bins)  # compute the loss
            model.zero_grad()
            loss.backward()
            # warmup_lr = args.lr * min(1, i * batch_size / (50000 * 10))
            warmup_lr = args.lr
            optimizer.param_groups[0]["lr"] = warmup_lr
            optimizer.step()

            pbar.set_description(
                f"Loss: {loss.item():.5f}; logP: {log_p.item():.5f}; logdet: {log_det.item():.5f}; lr: {warmup_lr:.7f}"
            )

            if i % 100 == 0:  # draw samples every 100 iterations
                with torch.no_grad():
                    utils.save_image(
                        model_single.reverse(z_sample).cpu().data,  # generate images by running the sampled z's through the model in reverse
                        f"sample/{str(i + 1).zfill(6)}.png",
                        normalize=True,
                        nrow=10,
                        range=(-0.5, 0.5),
                    )

            if i % 10000 == 0:
                torch.save(
                    model.state_dict(), f"checkpoint/model_{str(i + 1).zfill(6)}.pt"
                )
                torch.save(
                    optimizer.state_dict(), f"checkpoint/optim_{str(i + 1).zfill(6)}.pt"
                )


if __name__ == "__main__":
    args = parser.parse_args()
    print(args)

    # Build the Glow model
    model_single = Glow(3, args.n_flow, args.n_block, affine=args.affine, conv_lu=not args.no_lu)
    model = nn.DataParallel(model_single)
    # model = model_single
    model = model.to(device)

    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    train(args, model, optimizer)

These notes annotate and walk through the code of a Glow-pytorch repository on GitHub. If you spot any problem or mistake, please point it out in the comments so we can learn from each other.
