StyleGAN-代码实现分析

StyleGAN

前言

StyleGAN 模型架构

一、整体思路

常用的在线官方手册:
Python官方文档
Python可视化文档
Pytorch官方文档
torch.cat
Code Specialist

StyleGAN相关网址:
StyleGAN v1的论文
StyleGAN Pytorch 第三方开源代码
代码分析1
代码分析2
G Lab实验室网址

代码复现过程遇到的问题:
CUDA 依赖版本安装管理

二、代码阅读

1. model

完整的StyleGAN结构,由Mapping Network和Synthesis network构成,如下代码实现:

## 完整的StyleGAN生成器定义
class StyledGenerator(nn.Module):
    def __init__(self, code_dim=512, n_mlp=8):
        super().__init__()

        self.generator = Generator(code_dim) ##synthesis network

        ## mapping network定义,包含8个全连接层,n_mlp=8
        layers = [PixelNorm()] 
        for i in range(n_mlp): 
            layers.append(EqualLinear(code_dim, code_dim))
            layers.append(nn.LeakyReLU(0.2))

        ## mapping network f,用于从噪声向量Z生成Latent向量W
        self.style = nn.Sequential(*layers)

    def forward(
        self,
        input, ##输入向量Z
        noise=None, ##噪声,可选的
        step=0,##上采样因子
        alpha=1,##融合因子
        mean_style=None,##平均风格向量W
        style_weight=0,##风格向量权重
        mixing_range=(-1, 1),##混合区间变量
    ):
        styles = [] ##风格向量W
        if type(input) not in (list, tuple):
            input = [input]

        #print("混合的样本数input size="+str(len(input))) ##input=(1,(n_sample, 512))
        for i in input:
            styles.append(self.style(i)) ##调用mapping network,生成第i个风格向量W

        batch = input[0].shape[0] ## batchsize大小

        if noise is None:
            noise = []

            for i in range(step + 1): ## 0~8,共9层noise
                size = 4 * 2 ** i ## 每一层的尺度,第一层为4*4,每一层的各个通道共用噪声
                noise.append(torch.randn(batch, 1, size, size, device=input[0].device))

        ## 基于平均风格向量和当前生成的风格向量,获得完整的风格向量
        if mean_style is not None:
            styles_norm = [] ##风格数组[1*512]

            for style in styles:
                styles_norm.append(mean_style + style_weight * (style - mean_style))

            styles = styles_norm
            print("has mean_style,shape="+str(len(mean_style))+' '+str(mean_style[0].shape)) #1*512

        #print("混合的样本数styles.shape="+str(len(styles))+' '+str(styles[0].shape)) #styles[0].shape=batchsize*512
        #print("total step="+str(step)) #8
        return self.generator(styles, noise, step, alpha, mixing_range=mixing_range)

    ## 从Z向量生成平均W向量
    def mean_style(self, input):
        style = self.style(input).mean(0, keepdim=True)
        return style

Synthesis network 定义如下:

## synthesis network类定义
class Generator(nn.Module):
    def __init__(self, code_dim, fused=True):
        super().__init__()
        ##  9个尺度的卷积block,从4×4到64×64,使用双线性上采样;从64×64到1024×1024,使用转置卷积进行上采样
        self.progression = nn.ModuleList(
            [
                StyledConvBlock(512, 512, 3, 1, initial=True),  # 输出为4×4 constant
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 输出为8×8
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 输出为16×16
                StyledConvBlock(512, 512, 3, 1, upsample=True),  # 输出为32×32
                StyledConvBlock(512, 256, 3, 1, upsample=True),  # 输出为64×64
                StyledConvBlock(256, 128, 3, 1, upsample=True, fused=fused),  # 输出为128×128
                StyledConvBlock(128, 64, 3, 1, upsample=True, fused=fused),  # 输出为256×256
                StyledConvBlock(64, 32, 3, 1, upsample=True, fused=fused),  # 输出为512×512
                StyledConvBlock(32, 16, 3, 1, upsample=True, fused=fused),  # 输出为1024×1024
            ]
        )
        ## 9个尺度的1*1构成的to_rgb层,输入512个通道,输出3通道RGB图像,与前面styleconvblock对应
        self.to_rgb = nn.ModuleList(
            [
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(512, 3, 1),
                EqualConv2d(256, 3, 1),
                EqualConv2d(128, 3, 1),
                EqualConv2d(64, 3, 1),
                EqualConv2d(32, 3, 1),
                EqualConv2d(16, 3, 1),
            ]
        )

    def forward(self, style, noise, step=0, alpha=1, mixing_range=(-1, 1)):
        out = noise[0] ## 取噪声向量为输入

        if len(style) < 2: ## 输入style向量只有1个,不进行样式混合,inject_index=10
            inject_index = [len(self.progression) + 1]
            print("len(style)<2")
        else:
            ## 不止一个style向量,可以进行样式混合训练,生成长度为len(style) - 1))的样式混合交叉点序列,其数值大小不超过step
            ## step=8(8次上采样),len(style)=2,inject_index是一维数组,其中数在0~7之间
            inject_index = sorted(random.sample(list(range(step)), len(style) - 1))
   
        #print("inject_index="+str(inject_index)) ##default=10
        crossover = 0 ## 初始化用于mix的位置

        ## 遍历各级分辨率
        for i, (conv, to_rgb) in enumerate(zip(self.progression, self.to_rgb)):
            print("the resolution is="+str(4*np.power(2,i)))
            if mixing_range == (-1, 1):
                ## 根据前面生成的随机数,来决定样式混合的index,只考虑两个向量进行混合的情况
                # 当i < inject_index[crossover]时,style_step = style[0]
                # 当i > inject_index[crossover]时,style_step = style[1]
                if crossover < len(inject_index) and i > inject_index[crossover]:
                    crossover = min(crossover + 1, len(style))
                print("random mix style,crossover="+str(crossover))
                style_step = style[crossover] ## 获得交叉的style起始点

            else:
                ## 根据mixing_range来觉得样式混合的区间,mixing_range[0] <= i <= mixing_range[1]取style[1],其他取style[0]
                #print("fixed mix style range:"+str(mixing_range[0])+' to '+str(mixing_range[1])+' cur i is:'+str(i))
                if mixing_range[0] <= i <= mixing_range[1]:
                    style_step = style[1] #取第2个样本样式
                    #print("choose style 2")
                else:
                    style_step = style[0] #取第1个样本样式
                    #print("choose style 1")

            if i > 0 and step > 0:
                out_prev = out
                
            ## 将噪声与风格向量输入风格模块,conv=styleconvblock
            #print("batchsize="+str(len(style_step))+",style shape="+str(style_step[0].shape))
            out = conv(out, style_step, noise[i]) 

            if i == step: ## 最后1级分辨率,输出图片
                out = to_rgb(out) ##1×1卷积

                ## 最后结果是否与上一级分辨率进行alpha融合
                if i > 0 and 0 <= alpha < 1:
                    skip_rgb = self.to_rgb[i - 1](out_prev) ##获得上一级分辨率结果进行2倍上采样
                    skip_rgb = F.interpolate(skip_rgb, scale_factor=2, mode='nearest')
                    out = (1 - alpha) * skip_rgb + alpha * out

                break

        return out
  • AdaIN
    AdaIN中相关代码:
    在这里插入图片描述
## 自适应的IN层
class AdaptiveInstanceNorm(nn.Module):
    def __init__(self, in_channel, style_dim):
        super().__init__()

        self.norm = nn.InstanceNorm2d(in_channel) ##创建IN层
        self.style = EqualLinear(style_dim, in_channel * 2) ##从W向量变成AdaIN层系数

        self.style.linear.bias.data[:in_channel] = 1
        self.style.linear.bias.data[in_channel:] = 0

    def forward(self, input, style):
        #print("AdaIN style input="+str(style.shape)) #默认值,风格向量长度512
        ## 输入style为风格向量,长度为512;经过self.style得到输出风格矩阵,通道数等于输入通道数的2倍
        style = self.style(style).unsqueeze(2).unsqueeze(3)
        gamma, beta = style.chunk(2, 1) ##获得缩放和偏置系数,按1轴(通道)分为2部分
        #print("AdaIN style output="+str(style.shape))
#等于输入通道数的2倍,in_channel*2

        out = self.norm(input) ##IN归一化
        out = gamma * out + beta

        return out

1. train_transform():图片预处理

torchvision.transforms是pytorch中的图像预处理包。一般用Compose把多个步骤整合到一起:

torch.transforms常用代码
Resize :把给定的图片resize到given size
RandomCrop:在一个随机的位置进行裁剪
ToTensor:convert a PIL image to tensor (HWC) in range [0,255] to a torch.Tensor(CHW) in the range [0.0,1.0]
ToTensor():能够把灰度范围从0-255变换到0-1之间,

def train_transform():
    transform_list = [
        transforms.Resize(size=(512, 512)),	
        transforms.RandomCrop(256),
        transforms.ToTensor()
    ]
    return transforms.Compose(transform_list)

2. argparse模块:

作用:命令行的解析器,方便编写用户友好的命令行接口。
在这里定义了输入路径,输出路径;learning_rate, learning_rate_decay;max_iter,batch_size等等。
使用三部曲:

  1. 创建解析器

parser = argparse.ArgumentParser()

  1. 添加参数

parser.add_argument(‘–content_dir’, type=str, required=True,
help=‘Directory path to a batch of content images’)

name or flags - 一个命名或者一个选项字符串的列表,例如 foo 或 -f, --foo。
default - 当参数未在命令行中出现时使用的值。
type - 命令行参数应当被转换成的类型。
required - 此命令行选项是否可省略 (仅选项可用)。
help - 一个此选项作用的简单描述。

  1. 解析

args = parser.parse_args()

以下是实例代码:

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content_dir', type=str, required=True,
                    help='Directory path to a batch of content images')
parser.add_argument('--style_dir', type=str, required=True,
                    help='Directory path to a batch of style images')
parser.add_argument('--vgg', type=str, default='models/vgg_normalised.pth')

# training options
parser.add_argument('--save_dir', default='./experiments',
                    help='Directory to save the model')
parser.add_argument('--log_dir', default='./logs',
                    help='Directory to save the log')
parser.add_argument('--lr', type=float, default=1e-4)
parser.add_argument('--lr_decay', type=float, default=5e-5)
parser.add_argument('--max_iter', type=int, default=160000)
parser.add_argument('--batch_size', type=int, default=8)
parser.add_argument('--style_weight', type=float, default=10.0)
parser.add_argument('--content_weight', type=float, default=1.0)
parser.add_argument('--n_threads', type=int, default=16)
parser.add_argument('--save_model_interval', type=int, default=10000)

args = parser.parse_args()

3. optimizer.zero_grad(), loss.backward(), optimizer.step()

参数更新三件套
1)梯度归零(optimizer.zero_grad())
2)反向传播计算梯度(loss.backward())
3)梯度下降参数更新(optimizer.step())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

4. model.py 阅读及解析

IN Norm

3.待读资料

从Style Transform 到StyleGAN

三、python基础知识

1.Shape

shape属性可以获取矩阵的形状(例如二维数组的行列),获取的结果是一个元组,因此相关代码如下:

import numpy as np
x = np.array([[1,2,5],[2,3,5],[3,4,5],[2,3,6]])
#输出数组的行和列数
print x.shape  #结果: (4, 3)
#只输出行数
print x.shape[0] #结果: 4
#只输出列数
print x.shape[1] #结果: 3

2.pytorch中的torch.cat()和torch.chunk()

pytorch 中张量进行拼接和分割的函数分别是 torch.cat()和torch.chunk()。
torch.cat()是将多个张量组成的元组按照指定的维度进行拼接。
torch.chunk()是对一个张量按照某个维度分割成多个子张量块。

torch.cat((x,x,x),dim) dim=0 row; dim=1 col.

>>> import torch
>>> x = torch.randn(2,3)
>>> x
tensor([[-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),0)					##0代表按行拼接
tensor([[-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028],
        [-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028],
        [-0.5654,  0.7048,  0.5851],
        [-1.3871,  0.5481,  0.3028]])
>>> torch.cat((x,x,x),1)					##1代表按列拼接
tensor([[-0.5654,  0.7048,  0.5851, -0.5654,  0.7048,  0.5851, -0.5654,  0.7048,
          0.5851],
        [-1.3871,  0.5481,  0.3028, -1.3871,  0.5481,  0.3028, -1.3871,  0.5481,
          0.3028]])

torch.chunk(chunks,dim) dim=0 row; dim=1 col.

>>> import torch
>>> x = torch.randn(8,8)
>>> x
tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],
        [ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],
        [-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],
        [ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343],
        [-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],
        [-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],
        [ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]])
>>> x.chunk(chunks=2, dim=0)					####0代表按行拆分
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435, -0.1774,  0.7908,  0.6920,  1.0908],
        [ 0.8614, -0.3212,  0.4715,  0.1476,  1.7950,  1.8308, -0.1419, -0.1448],
        [-0.7407,  0.5510,  0.1284,  0.1485,  0.2997, -0.8133,  1.5608,  0.0682],
        [ 0.7217,  0.5292,  0.2469,  0.1823, -0.6200,  0.9436, -0.5221, -0.9343]]), 
 tensor([[-2.0195, -2.3613, -0.6441, -1.7863,  1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.4582, -1.6445, -0.6813, -0.8802,  0.9870, -0.6599, -0.4719,  0.3088],
        [-1.6415, -0.9834,  0.1687,  0.0159,  0.4456, -0.1823,  0.9652, -0.2785],
        [ 0.8765,  0.8214,  1.0971, -0.4150, -0.9499, -0.5875, -1.3902, -0.9129]]))
>>> x.chunk(chunks=2, dim=1)				####1代表按列拆分
(tensor([[ 1.0272,  1.5964,  0.1502,  1.3435],
        [ 0.8614, -0.3212,  0.4715,  0.1476],
        [-0.7407,  0.5510,  0.1284,  0.1485],
        [ 0.7217,  0.5292,  0.2469,  0.1823],
        [-2.0195, -2.3613, -0.6441, -1.7863],
        [ 0.4582, -1.6445, -0.6813, -0.8802],
        [-1.6415, -0.9834,  0.1687,  0.0159],
        [ 0.8765,  0.8214,  1.0971, -0.4150]]), 
 tensor([[-0.1774,  0.7908,  0.6920,  1.0908],
        [ 1.7950,  1.8308, -0.1419, -0.1448],
        [ 0.2997, -0.8133,  1.5608,  0.0682],
        [-0.6200,  0.9436, -0.5221, -0.9343],
        [ 1.4207,  0.4124,  0.5508, -0.2569],
        [ 0.9870, -0.6599, -0.4719,  0.3088],
        [ 0.4456, -0.1823,  0.9652, -0.2785],
        [-0.9499, -0.5875, -1.3902, -0.9129]]))

3.pytorch 中的 torch.squeeze()和torch.unsqueeze()

torch.squeeze 对张量维度进行压缩,压缩掉维度为1 的

总结

提示:这里对文章进行总结:

  • 8
    点赞
  • 43
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值