import argparse
import datetime
import os
import sys
import time
from typing import Any, Dict, Optional

import numpy as np
import torch
import torchvision as tv
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision.utils import save_image


def str2bool(v):
    """Parse a command-line flag value into a bool.

    Args:
        v: either an actual ``bool`` (returned unchanged) or a string such as
            "yes"/"no", "true"/"false", "t"/"f", "y"/"n", "1"/"0"
            (matched case-insensitively).

    Returns:
        The boolean the value stands for.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised spelling.
    """
    if isinstance(v, bool):
        return v
    lowered = v.lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    # This raise used to sit after `return False` and was unreachable, so an
    # unrecognised value silently produced None instead of an argparse error.
    raise argparse.ArgumentTypeError("Boolean value expected.")

# Command-line configuration. parse_known_args() is used so that unrelated
# arguments on the command line are ignored instead of raising an error.
parser = argparse.ArgumentParser("Train Progressively grown GAN",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter, )
parser.add_argument("--rec_dir", action="store", type=str2bool, default=True,
                    help="whether images stored under one folder or has a recursive dir structure",
                    required=False, )
parser.add_argument("--flip_horizontal", action="store", type=str2bool, default=True,
                    help="whether to apply mirror augmentation", required=False, )
parser.add_argument("--depth", action="store", type=int, default=5,
                    help="对应generator和discriminator的深度,depth=5适用于渐进训练到32*32网络也就是cifar-10,高分辨率的以此类推是2的n次方n=depth的分辨率", required=False, )
parser.add_argument("--num_channels", action="store", type=int, default=3,
                    help="输入图像的通道数,默认是RGB三通道", required=False, )
parser.add_argument("--latent_size", action="store", type=int, default=128,
                    help="generator与discriminator中间的隐层神经元节点个数,如果是高分辨率图像可以改为512", required=False, )
parser.add_argument("--use_eql", action="store", type=str2bool, default=True,
                    help="控制是否使用均衡学习率策略", required=False, )
parser.add_argument("--use_ema", action="store", type=str2bool, default=True,
                    help="是否使用指数滑动平均策略", required=False, )
parser.add_argument("--ema_beta", action="store", type=float, default=0.999, help="value of the ema beta",
                    required=False, )
# The help text of the next two options was a copy-paste leftover
# ("Mapper network configuration") that did not describe them.
parser.add_argument("--epochs", action="store", type=int, required=False, nargs="+",
                    default=[172 for _ in range(9)], help="number of epochs to train for at each depth", )
parser.add_argument("--batch_sizes", action="store", type=int, required=False, nargs="+",
                    default=[512, 256, 128, 128, ], help="batch size to use at each depth", )
parser.add_argument("--fade_in_percentages", action="store", type=int, required=False, nargs="+",
                    default=[50 for _ in range(9)], help="百分之多少的epoch采用平滑过渡也就是渐进式训练的那个阿尔法进行平滑过渡", )
args = parser.parse_known_args()[0]
# Epochs for progressive stages 1-4; higher resolutions get more epochs.
# NOTE: the training loop reads this list rather than args.epochs.
num_epochs = [4, 8, 16, 20]
# Percentage of each stage's epochs spent fading in: alpha ramps up over the
# first half and stays at 1 afterwards.
# NOTE: the training loop reads this list rather than args.fade_in_percentages.
fade_ins = [50, 50, 50, 50]
gen_learning_rate = 0.003
dis_learning_rate = 0.003


# Dataset: CIFAR-10 loaded through torchvision (the ten class names are below).
# NOTE(review): the previous comment here described 256*256 building-facade
# photos with segmentation maps, which does not match the CIFAR-10 code.
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

data_path = "../DATA/"
# ToTensor followed by Normalize(0.5, 0.5); show_img undoes this with x / 2 + 0.5.
# NOTE: this rebinds `transforms`, shadowing the `torchvision.transforms`
# module alias imported at the top of the file; the name is kept so existing
# references to it keep working.
transforms = tv.transforms.Compose(
    [tv.transforms.ToTensor(), tv.transforms.Normalize([0.5], [0.5])]
)

# NOTE(review): the original call was cut off after `root=data_path,`. The
# remaining arguments are reconstructed (training split, download if missing,
# apply the transform above) - confirm they match the intended setup.
trainset = tv.datasets.CIFAR10(root=data_path, train=True, download=True, transform=transforms)
from torch.utils.data import DataLoader, Dataset


def get_data_loader(dataset: Dataset, batch_size: int, num_workers: int = 3) -> DataLoader:
    """Build a shuffling DataLoader for the given dataset.

    Args:
        dataset: torch dataset object to read from
        batch_size: batch size for training
        num_workers: number of parallel worker processes reading the data

    Returns:
        A DataLoader over ``dataset``. ``drop_last=True`` means every batch has
        exactly ``batch_size`` samples and a trailing partial batch is skipped.
    """
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=True)


from torch.autograd import Variable
import matplotlib.pyplot as plt

def show_img(img, trans=True):
    """Draw a tensor image with matplotlib.

    Args:
        img: tensor supporting ``detach().cpu().numpy()``; its axes are
            reordered with (1, 2, 0) before plotting.
        trans: when false the function returns without drawing anything.
    """
    if not trans:
        return
    # Move the channel axis last, which is the layout imshow is given here.
    channels_last = img.detach().cpu().numpy().transpose((1, 2, 0))
    # x / 2 + 0.5 is the inverse of a Normalize(0.5, 0.5) transform.
    plt.imshow(channels_last / 2 + 0.5)

# Quick sanity check of the dataset: draw the first two training samples.
for i in range(2):
    sample = trainset[i]
    # Only the first element of the sample (the image) is displayed.
    show_img(sample[0], trans=True)





from torch.nn import Conv2d, ConvTranspose2d, Linear
from torch import Tensor

class EqualizedConv2d(Conv2d):
    """2-D convolution with equalized learning rate.

    The weights are drawn from N(0, 1) and multiplied by the constant
    ``sqrt(2) / sqrt(fan_in)`` on every forward pass, instead of being scaled
    once at initialisation.

    Args:
        Same as ``torch.nn.Conv2d`` (in_channels, out_channels, kernel_size,
        stride, padding, dilation, groups, bias, padding_mode).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 bias=True, padding_mode="zeros", ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias,
                         padding_mode, )
        # Re-initialise: unit-normal weights and a zero bias. The runtime scale
        # below takes the place of a variance-scaled initialiser.
        torch.nn.init.normal_(self.weight)
        if bias:
            torch.nn.init.zeros_(self.bias)

        # Scale for the weights: fan_in is kernel area times input channels.
        fan_in = np.prod(self.kernel_size) * self.in_channels
        self.scale = np.sqrt(2) / np.sqrt(fan_in)

    def forward(self, x: Tensor) -> Tensor:
        """Convolve ``x`` with the stored weights multiplied by ``self.scale``."""
        return torch.conv2d(input=x, weight=self.weight * self.scale, bias=self.bias, stride=self.stride,
                            padding=self.padding, dilation=self.dilation, groups=self.groups, )

class EqualizedConvTranspose2d(ConvTranspose2d):
    """2-D transposed convolution with equalized learning rate.

    The weights are drawn from N(0, 1) and multiplied by the constant
    ``sqrt(2) / sqrt(in_channels)`` on every forward pass.

    Args:
        Same as ``torch.nn.ConvTranspose2d`` (in_channels, out_channels,
        kernel_size, stride, padding, output_padding, groups, bias, dilation,
        padding_mode).
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, groups=1,
                 bias=True, dilation=1, padding_mode="zeros", ) -> None:
        super().__init__(in_channels, out_channels, kernel_size, stride, padding, output_padding, groups, bias,
                         dilation, padding_mode, )
        # Re-initialise: unit-normal weights and a zero bias. The runtime scale
        # below takes the place of a variance-scaled initialiser.
        torch.nn.init.normal_(self.weight)
        if bias:
            torch.nn.init.zeros_(self.bias)

        # Scale for the weights: fan_in is taken as the input channel count.
        fan_in = self.in_channels
        self.scale = np.sqrt(2) / np.sqrt(fan_in)

    def forward(self, x: Tensor, output_size: Any = None) -> Tensor:
        """Apply the transposed convolution with runtime-scaled weights.

        Args:
            x: input activations
            output_size: optional requested output size, used to pick the
                output padding

        Returns:
            The transposed-convolution output.
        """
        if output_size is None:
            # Nothing to resolve: use the padding configured at construction.
            output_padding = self.output_padding
        else:
            # The old code passed the builtin `input` here instead of `x`.
            # The private helper's signature differs between torch releases
            # (newer ones take num_spatial_dims), so try the newer form first.
            try:
                output_padding = self._output_padding(
                    x, output_size, self.stride, self.padding, self.kernel_size, 2, self.dilation)
            except TypeError:
                output_padding = self._output_padding(
                    x, output_size, self.stride, self.padding, self.kernel_size)
        return torch.conv_transpose2d(input=x, weight=self.weight * self.scale, bias=self.bias,
                                      stride=self.stride, padding=self.padding, output_padding=output_padding,
                                      groups=self.groups, dilation=self.dilation, )

class EqualizedLinear(Linear):
    """Fully connected layer with equalized learning rate.

    The weights are drawn from N(0, 1) and multiplied by the constant
    ``sqrt(2) / sqrt(in_features)`` on every forward pass.

    Args:
        in_features: size of each input sample
        out_features: size of each output sample
        bias: whether to learn an additive bias
    """

    def __init__(self, in_features, out_features, bias=True) -> None:
        super().__init__(in_features, out_features, bias)

        # Re-initialise: unit-normal weights and a zero bias. The runtime scale
        # below takes the place of a variance-scaled initialiser.
        torch.nn.init.normal_(self.weight)
        if bias:
            torch.nn.init.zeros_(self.bias)

        # Scale for the weights.
        fan_in = self.in_features
        self.scale = np.sqrt(2) / np.sqrt(fan_in)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the linear map with the weights multiplied by ``self.scale``."""
        return torch.nn.functional.linear(x, self.weight * self.scale, self.bias)



class MinibatchStdDev(torch.nn.Module):
    """Minibatch standard deviation layer for the discriminator.

    Appends one extra feature map holding a standard-deviation statistic
    computed across samples of the batch.

    Args:
        group_size: size of each group into which the batch is split
    """

    def __init__(self, group_size: int = 4) -> None:
        super(MinibatchStdDev, self).__init__()
        self.group_size = group_size

    def extra_repr(self) -> str:
        return f"group_size={self.group_size}"

    def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor:
        """Append the minibatch standard-deviation map to ``x``.

        Args:
            x: input activation volume of shape [B x C x H x W]
            alpha: small number added under the square root for numerical
                stability

        Returns:
            ``x`` with one constant-per-sample channel appended,
            shape [B x (C + 1) x H x W].
        """
        batch_size, channels, height, width = x.shape
        if batch_size > self.group_size:
            assert batch_size % self.group_size == 0, (
                f"batch_size {batch_size} should be "
                f"perfectly divisible by group_size {self.group_size}"
            )
            group_size = self.group_size
        else:
            # Batches no larger than the group size form a single group.
            group_size = batch_size

        # Reshape x so the group axis comes first.
        y = torch.reshape(x, [group_size, -1, channels, height, width])

        # Indicated shapes are after performing the operation.
        y = y - y.mean(dim=0, keepdim=True)  # [G x M x C x H x W] subtract the mean over the group
        y = torch.sqrt((y * y).mean(dim=0, keepdim=False) + alpha)  # [M x C x H x W] std over the group
        y = y.mean(dim=[1, 2, 3], keepdim=True)  # [M x 1 x 1 x 1] average over feature maps and pixels
        y = y.repeat(group_size, 1, height, width)  # [B x 1 x H x W] replicate over group and pixels
        y = torch.cat([x, y], 1)  # [B x (C + 1) x H x W] append as a new feature map

        return y

# 3.1.3 Pixel-wise normalization


class PixelwiseNorm(torch.nn.Module):
    """Pixel-wise feature normalization.

    Divides each pixel's feature vector by the root mean square of its values
    across the channel dimension.
    """

    def __init__(self):
        super(PixelwiseNorm, self).__init__()

    def forward(self, x: Tensor, alpha: float = 1e-8) -> Tensor:
        """Normalize ``x`` along the channel dimension.

        Args:
            x: input activation volume of shape [N x C x H x W]
            alpha: small number added under the square root for numerical
                stability

        Returns:
            Tensor of the same shape as ``x``.
        """
        # [N x 1 x H x W]: mean of squares over channels (keepdim keeps the
        # pixel dimensions), plus alpha, then the square root.
        y = x.pow(2.0).mean(dim=1, keepdim=True).add(alpha).sqrt()
        y = x / y  # normalize the input x volume
        return y


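# The generator grows from 4 x 4 up to 32 x 32 using 4 blocks: one initial block (GenInitialBlock) and three general blocks. It uses pixel-wise normalization and LeakyReLU activations, and upsamples with a transposed convolution and nearest-neighbour resizing.

# 3.2.1 General block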
from typing import Any, Dict, Optional
from torch.nn.functional import interpolate
from torch.nn import AvgPool2d, Conv2d, ConvTranspose2d, Embedding, LeakyReLU, Module

class GenGeneralConvBlock(torch.nn.Module):
    """General convolutional block of the generator.

    Upsamples the input by a factor of two and applies two 3x3 convolutions,
    each followed by LeakyReLU and pixel-wise normalization.

    Args:
        in_channels: number of input channels to the block
        out_channels: number of output channels required
        use_eql: whether to use equalized learning rate
    """

    def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
        super(GenGeneralConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels  # was mistakenly assigned in_channels
        self.use_eql = use_eql

        ConvBlock = EqualizedConv2d if use_eql else Conv2d

        self.conv_1 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True)
        self.conv_2 = ConvBlock(out_channels, out_channels, (3, 3), padding=1, bias=True)
        self.pixNorm = PixelwiseNorm()
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Upsample ``x`` by two and run both conv / LeakyReLU / pixel-norm stages."""
        # interpolate can up- or down-sample; scale_factor=2 doubles each
        # spatial dimension using its default mode.
        y = interpolate(x, scale_factor=2)
        y = self.pixNorm(self.lrelu(self.conv_1(y)))
        y = self.pixNorm(self.lrelu(self.conv_2(y)))

        return y
# 3.2.2 GenInitialBlock
class GenInitialBlock(Module):
    """Initial block of the generator.

    Expands a latent vector into a 4 x 4 feature map with a 4x4 transposed
    convolution, then applies a 3x3 convolution.

    Args:
        in_channels: number of input channels to the block
        out_channels: number of output channels of the block
        use_eql: whether to use equalized learning rate
    """

    def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
        super(GenInitialBlock, self).__init__()
        self.use_eql = use_eql

        ConvBlock = EqualizedConv2d if use_eql else Conv2d
        ConvTransposeBlock = EqualizedConvTranspose2d if use_eql else ConvTranspose2d

        self.conv_1 = ConvTransposeBlock(in_channels, out_channels, (4, 4), bias=True)
        self.conv_2 = ConvBlock(out_channels, out_channels, (3, 3), padding=1, bias=True)
        self.pixNorm = PixelwiseNorm()
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Map latent vectors to the first feature map.

        Args:
            x: latent input; two trailing singleton dimensions are appended
                before the transposed convolution

        Returns:
            Feature map produced by the block.
        """
        y = torch.unsqueeze(torch.unsqueeze(x, -1), -1)
        y = self.pixNorm(y)  # normalize the latents to hypersphere
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))
        y = self.pixNorm(y)
        return y
# 3.2.3 Function that computes the number of feature maps
from torch.nn import Conv2d, ModuleList

def nf(stage: int, fmap_base: int = 16 << 10, fmap_decay: float = 1.0, fmap_min: int = 1,
       fmap_max: int = 512,) -> int:
    """Compute the number of feature maps present in a stage.

    The count is ``fmap_base / 2 ** (stage * fmap_decay)``, truncated to an
    int and clipped to ``[fmap_min, fmap_max]``.

    Args:
        stage: stage level
        fmap_base: base number of fmaps
        fmap_decay: decay rate for the fmaps in the network
        fmap_min: minimum number of fmaps
        fmap_max: maximum number of fmaps

    Returns:
        Number of fmaps that should be present at that stage.
    """
    unclipped = int(fmap_base / (2.0 ** (stage * fmap_decay)))
    return int(np.clip(unclipped, fmap_min, fmap_max).item())

class Generator(torch.nn.Module):
    """Generator of the progressively grown GAN.

    Args:
        depth: required depth of the network
        num_channels: number of output channels (default = 3 for RGB)
        latent_size: size of the latent manifold
        use_eql: whether to use equalized learning rate
    """

    def __init__(self, depth: int = 10, num_channels: int = 3, latent_size: int = 512,
                 use_eql: bool = True,) -> None:
        # Must run before any sub-module is assigned; this call was missing.
        super().__init__()

        # object state:
        self.depth = depth
        self.latent_size = latent_size
        self.num_channels = num_channels
        self.use_eql = use_eql

        ConvBlock = EqualizedConv2d if use_eql else Conv2d

        self.layers = ModuleList([GenInitialBlock(latent_size, nf(1), use_eql=self.use_eql)])
        for stage in range(1, depth - 1):
            self.layers.append(GenGeneralConvBlock(nf(stage), nf(stage + 1), use_eql))

        # One 1x1 to-RGB converter per block, so every stage can emit an image.
        self.rgb_converters = ModuleList(
            [ConvBlock(nf(stage), num_channels, kernel_size=(1, 1)) for stage in range(1, depth)])

    def forward(self, x: Tensor, depth: int, alpha: float) -> Tensor:
        """Generate images at the resolution of the requested depth.

        Args:
            x: input latent noise
            depth: depth from where the network's output is required
                (2 is the first block)
            alpha: value of alpha for the fade-in effect; 1 uses only the
                newest block, 0 only the upsampled output of the previous one

        Returns:
            Generated images at the given depth's resolution.
        """
        assert 2 <= depth <= self.depth, f"Requested output depth {depth} cannot be produced"

        if depth == 2:
            y = self.rgb_converters[0](self.layers[0](x))
        else:
            y = x
            for layer_block in self.layers[: depth - 2]:
                y = layer_block(y)
            # Blend the upsampled image of the previous stage with the image
            # of the newest stage.
            residual = interpolate(self.rgb_converters[depth - 3](y), scale_factor=2)
            straight = self.rgb_converters[depth - 2](self.layers[depth - 2](y))
            y = (alpha * straight) + ((1 - alpha) * residual)
        return y

    def get_save_info(self) -> Dict[str, Any]:
        """Return the constructor configuration together with the state dict."""
        return {
            "conf": {
                "depth": self.depth,
                "num_channels": self.num_channels,
                "latent_size": self.latent_size,
                "use_eql": self.use_eql,
            },
            "state_dict": self.state_dict(),
        }

# NOTE(review): the original call was cut off after `latent_size=...,`;
# use_eql is reconstructed from the --use_eql option - confirm.
generator = Generator(depth=args.depth, num_channels=args.num_channels, latent_size=args.latent_size,
                      use_eql=args.use_eql)



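# The discriminator accepts inputs from 4 x 4 up to 32 x 32 and likewise has 4 blocks: one output-layer block and three general blocks. It uses a minibatch standard-deviation layer and LeakyReLU activations, and downsamples with average pooling.

# 3.3.1 DisGeneralConvBlock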
class DisGeneralConvBlock(torch.nn.Module):
    """General block in the discriminator.

    Two 3x3 convolutions with LeakyReLU, followed by 2x average pooling.

    Args:
        in_channels: number of input channels
        out_channels: number of output channels
        use_eql: whether to use equalized learning rate
    """

    def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
        super(DisGeneralConvBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_eql = use_eql

        ConvBlock = EqualizedConv2d if use_eql else Conv2d

        self.conv_1 = ConvBlock(in_channels, in_channels, (3, 3), padding=1, bias=True)
        self.conv_2 = ConvBlock(in_channels, out_channels, (3, 3), padding=1, bias=True)
        self.downSampler = AvgPool2d(2)
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Run both convolutions and halve the spatial resolution."""
        y = self.lrelu(self.conv_1(x))
        y = self.lrelu(self.conv_2(y))
        y = self.downSampler(y)
        return y

# 3.3.2 DisFinalBlock: the discriminator's output layer
class DisFinalBlock(torch.nn.Module):
    """Final block for the discriminator.

    Appends the minibatch standard-deviation map, then applies a 3x3
    convolution, a 4x4 convolution without padding and a 1x1 convolution down
    to a single channel.

    Args:
        in_channels: number of input channels
        out_channels: number of channels produced by the 4x4 convolution
        use_eql: whether to use equalized learning rate
    """

    def __init__(self, in_channels: int, out_channels: int, use_eql: bool) -> None:
        super(DisFinalBlock, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_eql = use_eql

        ConvBlock = EqualizedConv2d if use_eql else Conv2d

        # +1 input channel for the map appended by MinibatchStdDev.
        self.conv_1 = ConvBlock(in_channels + 1, in_channels, (3, 3), padding=1, bias=True)
        self.conv_2 = ConvBlock(in_channels, out_channels, (4, 4), bias=True)
        self.conv_3 = ConvBlock(out_channels, 1, (1, 1), bias=True)
        self.batch_discriminator = MinibatchStdDev()
        self.lrelu = LeakyReLU(0.2)

    def forward(self, x: Tensor) -> Tensor:
        """Return the raw scores for ``x`` flattened to one dimension."""
        y = self.batch_discriminator(x)
        y = self.lrelu(self.conv_1(y))
        y = self.lrelu(self.conv_2(y))
        y = self.conv_3(y)
        return y.view(-1)
# 3.3.3 Discriminator: assembling the discriminator
class Discriminator(torch.nn.Module):
    """Discriminator of the progressively grown GAN.

    Args:
        depth: depth of the discriminator. log_2(resolution)
        num_channels: number of channels of the input images (Default = 3 for RGB)
        latent_size: latent size of the final layer
        use_eql: whether to use the equalized learning rate
        num_classes: number of classes for a conditional discriminator
            (Default = None, meaning unconditional). It is only stored and
            reported by get_save_info; forward does not use it.
    """

    def __init__(
        self, depth: int = 7, num_channels: int = 3, latent_size: int = 512, use_eql: bool = True,
        num_classes: Optional[int] = None,) -> None:
        # Must run before any sub-module is assigned; this call was missing.
        super().__init__()
        self.depth = depth
        self.num_channels = num_channels
        self.latent_size = latent_size
        self.use_eql = use_eql
        self.num_classes = num_classes

        ConvBlock = EqualizedConv2d if use_eql else Conv2d

        # Built back to front: the final block is last and each larger stage
        # is inserted in front of it.
        blocks = [DisFinalBlock(nf(1), latent_size, use_eql)]
        for stage in range(1, depth - 1):
            blocks.insert(0, DisGeneralConvBlock(nf(stage + 1), nf(stage), use_eql))
        self.layers = ModuleList(blocks)
        # Reversed so that negative indices line up with self.layers.
        self.from_rgb = ModuleList(
            reversed([ConvBlock(num_channels, nf(stage), kernel_size=(1, 1)) for stage in range(1, depth)]))

    def forward(self, x: Tensor, depth: int, alpha: float) -> Tensor:
        """Score a batch of images at the given depth.

        Args:
            x: input to the network
            depth: current depth of operation (Progressive GAN)
            alpha: current value of alpha for fade-in; 1 uses only the newest
                block, 0 only the downsampled input fed to the previous stage

        Returns:
            Raw discriminator scores.
        """
        assert (depth <= self.depth), f"Requested output depth {depth} cannot be evaluated"

        if depth > 2:
            # Referenced through torch.nn.functional so the class does not
            # depend on an `avg_pool2d` import placed later in the file.
            pooled = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
            residual = self.from_rgb[-(depth - 2)](pooled)
            straight = self.layers[-(depth - 1)](self.from_rgb[-(depth - 1)](x))
            y = (alpha * straight) + ((1 - alpha) * residual)
            for layer_block in self.layers[-(depth - 2) : -1]:
                y = layer_block(y)
        else:
            y = self.from_rgb[-1](x)
        y = self.layers[-1](y)
        return y

    def get_save_info(self) -> Dict[str, Any]:
        """Return the constructor configuration together with the state dict."""
        return {
            "conf": {
                "depth": self.depth,
                "num_channels": self.num_channels,
                "latent_size": self.latent_size,
                "use_eql": self.use_eql,
                "num_classes": self.num_classes,
            },
            "state_dict": self.state_dict(),
        }
# NOTE(review): the original call was cut off after `latent_size=...,`;
# use_eql is reconstructed from the --use_eql option - confirm.
discriminator = Discriminator(depth=args.depth, num_channels=args.num_channels, latent_size=args.latent_size,
                              use_eql=args.use_eql)


# Use the WGAN-GP loss

# NOTE(review): `GANLoss` is neither defined nor imported in the visible part
# of this file - confirm it is available before this class statement runs.
class WganGP(GANLoss):
    """WGAN-GP loss function.

    The discriminator is required for computing the gradient penalty.

    Args:
        drift: weight for the drift penalty
    """

    def __init__(self, drift: float = 0.001) -> None:
        self.drift = drift

    @staticmethod
    def _gradient_penalty(dis: Discriminator, real_samples: Tensor, fake_samples: Tensor, depth: int,
                          alpha: float, reg_lambda: float = 10, labels: Optional[Tensor] = None, ) -> Tensor:
        """Private helper for calculating the gradient penalty.

        Declared static because it takes no ``self``; dis_loss calls it as
        ``self._gradient_penalty(discriminator, ...)``.

        Args:
            dis: the discriminator used for computing the penalty
            real_samples: real samples
            fake_samples: fake samples
            depth: current depth in the optimization
            alpha: current alpha for fade-in
            reg_lambda: regularisation lambda
            labels: accepted for signature compatibility; not used

        Returns:
            Computed gradient penalty.
        """
        batch_size = real_samples.shape[0]

        # generate random epsilon
        epsilon = torch.rand((batch_size, 1, 1, 1)).to(real_samples.device)

        # create the merge of both real and fake samples
        merged = epsilon * real_samples + ((1 - epsilon) * fake_samples)
        # autograd.grad below differentiates w.r.t. `merged`, so it must track
        # gradients even when both inputs are detached / plain data.
        merged.requires_grad_(True)

        # forward pass
        op = dis(merged, depth, alpha)

        # perform backward pass from op to merged for obtaining the gradients
        gradient = torch.autograd.grad(outputs=op, inputs=merged, grad_outputs=torch.ones_like(op),
                                       create_graph=True, retain_graph=True, only_inputs=True, )[0]

        # calculate the penalty using these gradients
        gradient = gradient.view(gradient.shape[0], -1)
        penalty = reg_lambda * ((gradient.norm(p=2, dim=1) - 1) ** 2).mean()

        return penalty

    def dis_loss(self, discriminator: Discriminator, real_samples: Tensor, fake_samples: Tensor, depth: int,
                 alpha: float, labels: Optional[Tensor] = None, ) -> Tensor:
        """Discriminator loss: fake-minus-real score gap, drift term and gradient penalty.

        Args:
            discriminator: the discriminator being trained
            real_samples: real samples
            fake_samples: generated samples
            depth: current depth in the optimization
            alpha: current alpha for fade-in
            labels: accepted for signature compatibility; not used

        Returns:
            Scalar loss tensor.
        """
        real_scores = discriminator(real_samples, depth, alpha)
        fake_scores = discriminator(fake_samples, depth, alpha)
        loss = (torch.mean(fake_scores) - torch.mean(real_scores) + (self.drift * torch.mean(real_scores ** 2)))

        # calculate the WGAN-GP (gradient penalty)
        gp = self._gradient_penalty(discriminator, real_samples, fake_samples, depth, alpha)
        loss += gp

        return loss

    def gen_loss(self, discriminator: Discriminator, _: Tensor, fake_samples: Tensor, depth: int, alpha: float,
                 labels: Optional[Tensor] = None, ) -> Tensor:
        """Generator loss: negative mean discriminator score on the fake samples.

        The second positional argument (real samples) and ``labels`` are
        ignored.
        """
        fake_scores = discriminator(fake_samples, depth, alpha)
        return -torch.mean(fake_scores)

# Loss object shared by both optimisation helpers.
loss_fn = WganGP()

# Pick the device once; the networks are moved only when CUDA is present.
cuda = torch.cuda.is_available()
print("cuda_is_available =", cuda)
device = torch.device("cuda" if cuda else "cpu")
if cuda:
    generator = generator.cuda()
    discriminator = discriminator.cuda()


# NOTE(review): both Adam calls were cut off in the original (only the betas
# line survived). params / lr are reconstructed from the networks and the
# learning-rate constants defined near the top of the file - confirm.
gen_optim = torch.optim.Adam(
    params=generator.parameters(),
    lr=gen_learning_rate,
    betas=(0, 0.99),
)
dis_optim = torch.optim.Adam(
    params=discriminator.parameters(),
    lr=dis_learning_rate,
    betas=(0, 0.99),
)



from torch.nn.functional import avg_pool2d, interpolate

def progressive_downsample_batch(real_batch, depth, alpha):
    """Downsample real images to match the depth currently being trained.

    Reads the maximum depth from the module-level ``generator``.

    Args:
        real_batch: batch of real samples
        depth: depth at which training is going on
        alpha: current value of the fader alpha

    Returns:
        Modified real batch of samples: a blend of the batch pooled to the
        current resolution and, for depth > 2, the batch pooled one level
        further and upsampled back by a factor of two.
    """
    # downsample the real_batch for the given depth
    down_sample_factor = int(2 ** (generator.depth - depth))
    prior_downsample_factor = int(2 ** (generator.depth - depth + 1))
    ds_real_samples = avg_pool2d(real_batch, kernel_size=down_sample_factor, stride=down_sample_factor)

    if depth > 2:
        prior_ds_real_samples = interpolate(
            avg_pool2d(real_batch, kernel_size=prior_downsample_factor, stride=prior_downsample_factor),
            scale_factor=2,
        )
    else:
        # No previous stage exists at the first depth, so nothing is blended.
        prior_ds_real_samples = ds_real_samples

    # real samples are a linear combination of ds_real_samples and prior_ds_real_samples
    real_samples = (alpha * ds_real_samples) + ((1 - alpha) * prior_ds_real_samples)
    return real_samples

from torch.optim.optimizer import Optimizer

def optimize_discriminator(loss: GANLoss, dis_optim: Optimizer, noise: Tensor, real_batch: Tensor,
                           depth: int, alpha: float, labels: Optional[Tensor] = None, ) -> float:
    """Perform a single weight update step on the discriminator.

    Uses the module-level ``generator`` and ``discriminator``.

    Args:
        loss: the loss function to be used for the optimization
        dis_optim: discriminator optimizer
        noise: input noise for sample generation
        real_batch: real samples batch
        depth: current depth of optimization
        alpha: current alpha for fade-in
        labels: labels for conditional discrimination (not used here)

    Returns:
        Discriminator loss value.
    """
    real_samples = progressive_downsample_batch(real_batch, depth, alpha)

    # generate a batch of samples; detached so this step does not
    # backpropagate into the generator
    fake_samples = generator(noise, depth, alpha).detach()
    dis_loss = loss.dis_loss(discriminator, real_samples, fake_samples, depth, alpha)

    # optimize discriminator (these statements were missing, so the loss was
    # computed but never applied)
    dis_optim.zero_grad()
    dis_loss.backward()
    dis_optim.step()

    return dis_loss.item()

def optimize_generator(loss: GANLoss, gen_optim: Optimizer, noise: Tensor, real_batch: Tensor, depth: int,
                       alpha: float, labels: Optional[Tensor] = None, ) -> float:
    """Perform a single weight update step on the generator.

    Uses the module-level ``generator`` and ``discriminator``.

    Args:
        loss: the loss function to be used for the optimization
        gen_optim: generator optimizer
        noise: input noise for sample generation
        real_batch: real samples batch
        depth: current depth of optimization
        alpha: current alpha for fade-in
        labels: labels for conditional discrimination (not used here)

    Returns:
        Generator loss value.
    """
    real_samples = progressive_downsample_batch(real_batch, depth, alpha)

    # generate fake samples:
    fake_samples = generator(noise, depth, alpha)
    gen_loss = loss.gen_loss(discriminator, real_samples, fake_samples, depth, alpha)

    # optimize the generator (these statements were missing, so the loss was
    # computed but never applied)
    gen_optim.zero_grad()
    gen_loss.backward()
    gen_optim.step()

    return gen_loss.item()



import timeit

# Progressive training: one pass of the outer loop per depth, from 4 x 4
# (depth 2) up to the generator's full depth.
print("Starting the training process ... ")
start_depth = 2
num_workers = 0
batch_repeats = 2  # optimisation rounds run on every batch
batch_sizes = args.batch_sizes
epochs = num_epochs
dataset = trainset
fade_in_percentages = fade_ins
global_step = 0
# Reference point for the "Elapsed" log line. It used to be 0, which made the
# log print the time elapsed since the Unix epoch.
global_time = time.time()
for current_depth in range(start_depth, generator.depth + 1):
    current_res = int(2 ** current_depth)
    print(f"\n\nCurrently working on Depth: {current_depth}")
    print("Current resolution: %d x %d" % (current_res, current_res))
    depth_list_index = current_depth - 2
    current_batch_size = batch_sizes[depth_list_index]
    data = get_data_loader(dataset, current_batch_size, num_workers)
    ticker = 1
    for epoch in range(1, epochs[depth_list_index] + 1):
        start = timeit.default_timer()  # record time at the start of epoch
        print(f"\nEpoch: {epoch}")
        total_batches = len(data)
        # compute the fader point: the number of batches over which alpha
        # ramps from 0 to 1 at this depth
        fader_point = int((
            fade_in_percentages[depth_list_index] / 100) * epochs[depth_list_index] * total_batches)

        for (i, batch) in enumerate(data, start=1):
            # calculate the alpha for fading in the layers
            alpha = ticker / fader_point if ticker <= fader_point else 1
            # extract current batch of data for training
            print("label =", classes[batch[1][0]])
            show_img(batch[0][0], True)
            batch = batch[0]
            # NOTE(review): this assignment was cut off in the original;
            # moving the image batch to `device` is reconstructed - confirm.
            images = batch.to(device)
            gan_input = torch.randn(current_batch_size, generator.latent_size).to(device)
            print("z.shape =", gan_input.shape)
            gen_loss, dis_loss = None, None
            for _ in range(batch_repeats):
                dis_loss = optimize_discriminator(loss_fn, dis_optim, gan_input, images, current_depth, alpha)
                gen_loss = optimize_generator(loss_fn, gen_optim, gan_input, images, current_depth, alpha)
                print("gen_loss =", gen_loss)
                print("dis_loss =", dis_loss)
            global_step += 1
            # log output: once per epoch, on the first batch
            if (i == 1):
                elapsed = time.time() - global_time
                elapsed = str(datetime.timedelta(seconds=elapsed))
                print("Elapsed: [%s]  batch: %d  d_loss: %f  g_loss: %f" % (elapsed, i, dis_loss, gen_loss))

            # increment the alpha ticker and the step
            ticker += 1

        stop = timeit.default_timer()
        print("Time taken for epoch: %.3f secs" % (stop - start))
