深入理解StyleGAN-v2 Generator结构

在这里插入图片描述

这篇文章主要用于 通过解读Stylegan2的generator代码深入理解模型结构
请谨慎阅读,如有错误请及时指出,感谢支持与理解!

注:解读的项目为rosinality实现的pytorch版本stylegan2。

项目地址:github


请添加图片描述

models.py文件

class Generator(nn.Module):
    def __init__(
        self,
        size,	#输出图像的size,如果是ffhq的话就是1024
        style_dim,  #style的通道深度,一般是512
        n_mlp,  #mapping netowork的层数,一般是8
        channel_multiplier=2, #channel multiplier of the generator. config-f = 2, else = 1
        blur_kernel=[1, 3, 3, 1],
        lr_mlp=0.01,
    )
# Generator
layers = [PixelNorm()]

初始的随机噪声z需要先经过一层PixelNorm,对应图中的Normalize。

PixelNorm的公式如下:
b x , y = a x , y / 1 N ∑ j = 0 N − 1 ( a x , y j ) 2 + ϵ , w h e r e   ϵ = 1 0 − 8 b_{x,y} = a_{x,y} / \sqrt{\frac{1}{N}\sum_{j=0}^{N-1}(a_{x,y}^{j})^2+\epsilon },where \ \epsilon =10^{-8} bx,y=ax,y/N1j=0N1(ax,yj)2+ϵ ,where ϵ=108
PixelNorm的作用为:Normalize了z中的每个元素到单位长度附近,避免在训练过程中逐步失控的风险。可以理解为避免输入的随机噪声出现极端权重,改善了训练的稳定性。【将每个点归一化,除以模长】

同时选择PixelNorm还有一些好处为:它没有需要训练的参数,减轻了训练开销。

for i in range(n_mlp): #n_mlp默认为8
	layers.append(
		EqualLinear(
			style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
		)
	)
self.style = nn.Sequential(*layers)

Mapping-Network由8层EqualLinear网络构成,而EqualLinear则是一个全连接层。activation存在时就是1个Linear+fused_leaky_relu,否则只是一个Linear。

self.channels = {
    4: 512,
    8: 512,
    16: 512,
    32: 512,
    64: 256 * channel_multiplier,
    128: 128 * channel_multiplier,
    256: 64 * channel_multiplier,
    512: 32 * channel_multiplier,
    1024: 16 * channel_multiplier,
}

这里是在设置通道数,字典的key代表的是分辨率层,data则是通道数。它主要用在synthesis-network中8个block下的第一个conv层,改变输出的维度。

self.input = ConstantInput(self.channels[4]) 

ConstantInput下self.input默认为一个1* channel * size * size (1 * 512 * 4 * 4)的Parameter矩阵,对应图中。forward()函数是根据input的batch数扩展自己,返回batch * 512 * 4 * 4的矩阵。

self.conv1 = StyledConv(
            self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
        )
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)

self.log_size = int(math.log(size, 2))	
self.num_layers = (self.log_size - 2) * 2 + 1	#size 1024时对应17层

self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()

in_channel = self.channels[4]

for layer_idx in range(self.num_layers):
	res = (layer_idx + 5) // 2
	shape = [1, 1, 2 ** res, 2 ** res]
	self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))

for i in range(3, self.log_size + 1):
	out_channel = self.channels[2 ** i]
    #不断改变输出维度,用于Upsample的卷积层
	self.convs.append(
		StyledConv(
		in_channel,
		out_channel,
		3,
		style_dim,
		upsample=True,
		blur_kernel=blur_kernel,
                )
            )
    #用于提取特征的卷积层
	self.convs.append(
		StyledConv(
                    out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
                )
            )
    #实现跨层链接
	self.to_rgbs.append(ToRGB(out_channel, style_dim))

	in_channel = out_channel

上面便是Synthesis NetWork18层的构建,1+8 * 2 + 1(to_rgb)

def get_latent(self, input):
	return self.style(input)

这个函数对应将z转化成w空间向量。

下面来看Generator的forward函数:

if truncation < 1:
	style_t = []

for style in styles:
	style_t.append(
		truncation_latent + truncation * (style - truncation_latent)
	)

styles = style_t

这里是StyleGan做的Truncation Trick,具体解释如下:

从数据分布来说,低概率密度的数据在网络中的表达能力很弱,直观理解就是,低概率密度的数据出现次数少,能影响网络梯度的机会也少,但并不代表低概率密度的数据不重要。可以提高数据分布的整体密度,把分布稀疏的数据点都聚拢到一起,类似于PCA,做法很简单,首先找到数据中的一个平均点,然后计算其他所有点到这个平均点的距离,对每个距离按照统一标准进行压缩,这样就能将数据点都聚拢了,但是又不会改变点与点之间的距离关系。

请添加图片描述

请添加图片描述

if len(styles) < 2:
	inject_index = self.n_latent

	if styles[0].ndim < 3:
		latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)

	else:
		latent = styles[0]

else:
	if inject_index is None:
		inject_index = random.randint(1, self.n_latent - 1)

	latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
	latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)

	latent = torch.cat([latent, latent2], 1)

以上主要是样式混合时的代码函数。当len(styles) = 2,根据inject_index数混合合并latent。

out = self.input(latent) #产生batch*512*4*4随机矩阵
out = self.conv1(out, latent[:, 0], noise=noise[0]) #latent[:,i]对应每一层style-vector

skip = self.to_rgb1(out, latent[:, 1])

i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
	self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
	out = conv1(out, latent[:, i], noise=noise1)
	out = conv2(out, latent[:, i + 1], noise=noise2)
	skip = to_rgb(out, latent[:, i + 2], skip)

	i += 2

image = skip

以上就是Synthesis网络的主要推理流程了。
在这里插入图片描述

  • 10
    点赞
  • 15
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值