这篇文章主要用于 通过解读Stylegan2的generator代码深入理解模型结构
请谨慎阅读,如有错误请及时指出,感谢支持与理解!
注:解读的项目为rosinality实现的pytorch版本stylegan2。
项目地址:github
models.py文件
class Generator(nn.Module):
def __init__(
self,
size, #输出图像的size,如果是ffhq的话就是1024
style_dim, #style的通道深度,一般是512
n_mlp, #mapping netowork的层数,一般是8
channel_multiplier=2, #channel multiplier of the generator. config-f = 2, else = 1
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
)
# Generator
layers = [PixelNorm()]
初始的随机噪声z需要先经过一层PixelNorm,对应图中的Normalize。
PixelNorm的公式如下:
b
x
,
y
=
a
x
,
y
/
1
N
∑
j
=
0
N
−
1
(
a
x
,
y
j
)
2
+
ϵ
,
w
h
e
r
e
ϵ
=
1
0
−
8
b_{x,y} = a_{x,y} / \sqrt{\frac{1}{N}\sum_{j=0}^{N-1}(a_{x,y}^{j})^2+\epsilon },where \ \epsilon =10^{-8}
bx,y=ax,y/N1j=0∑N−1(ax,yj)2+ϵ,where ϵ=10−8
PixelNorm的作用为:Normalize了z中的每个元素到单位长度附近,避免在训练过程中逐步失控的风险。可以理解为避免输入的随机噪声出现极端权重,改善了训练的稳定性。【将每个点归一化,除以模长】
同时选择PixelNorm还有一些好处为:它没有需要训练的参数,减轻了训练开销。
for i in range(n_mlp): #n_mlp默认为8
layers.append(
EqualLinear(
style_dim, style_dim, lr_mul=lr_mlp, activation="fused_lrelu"
)
)
self.style = nn.Sequential(*layers)
Mapping-Network由8层EqualLinear网络构成,而EqualLinear则是一个全连接层。activation存在时就是1个Linear+fused_leaky_relu,否则只是一个Linear。
self.channels = {
4: 512,
8: 512,
16: 512,
32: 512,
64: 256 * channel_multiplier,
128: 128 * channel_multiplier,
256: 64 * channel_multiplier,
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
这里是在设置通道数,字典的key代表的是分辨率层,data则是通道数。它主要用在synthesis-network中8个block下的第一个conv层,改变输出的维度。
self.input = ConstantInput(self.channels[4])
ConstantInput下self.input默认为一个1* channel * size * size (1 * 512 * 4 * 4)的Parameter矩阵,对应图中。forward()函数是根据input的batch数扩展自己,返回batch * 512 * 4 * 4的矩阵。
self.conv1 = StyledConv(
self.channels[4], self.channels[4], 3, style_dim, blur_kernel=blur_kernel
)
self.to_rgb1 = ToRGB(self.channels[4], style_dim, upsample=False)
self.log_size = int(math.log(size, 2))
self.num_layers = (self.log_size - 2) * 2 + 1 #size 1024时对应17层
self.convs = nn.ModuleList()
self.upsamples = nn.ModuleList()
self.to_rgbs = nn.ModuleList()
self.noises = nn.Module()
in_channel = self.channels[4]
for layer_idx in range(self.num_layers):
res = (layer_idx + 5) // 2
shape = [1, 1, 2 ** res, 2 ** res]
self.noises.register_buffer(f"noise_{layer_idx}", torch.randn(*shape))
for i in range(3, self.log_size + 1):
out_channel = self.channels[2 ** i]
#不断改变输出维度,用于Upsample的卷积层
self.convs.append(
StyledConv(
in_channel,
out_channel,
3,
style_dim,
upsample=True,
blur_kernel=blur_kernel,
)
)
#用于提取特征的卷积层
self.convs.append(
StyledConv(
out_channel, out_channel, 3, style_dim, blur_kernel=blur_kernel
)
)
#实现跨层链接
self.to_rgbs.append(ToRGB(out_channel, style_dim))
in_channel = out_channel
上面便是Synthesis NetWork18层的构建,1+8 * 2 + 1(to_rgb)
def get_latent(self, input):
return self.style(input)
这个函数对应将z转化成w空间向量。
下面来看Generator的forward函数:
if truncation < 1:
style_t = []
for style in styles:
style_t.append(
truncation_latent + truncation * (style - truncation_latent)
)
styles = style_t
这里是StyleGan做的Truncation Trick,具体解释如下:
从数据分布来说,低概率密度的数据在网络中的表达能力很弱,直观理解就是,低概率密度的数据出现次数少,能影响网络梯度的机会也少,但并不代表低概率密度的数据不重要。可以提高数据分布的整体密度,把分布稀疏的数据点都聚拢到一起,类似于PCA,做法很简单,首先找到数据中的一个平均点,然后计算其他所有点到这个平均点的距离,对每个距离按照统一标准进行压缩,这样就能将数据点都聚拢了,但是又不会改变点与点之间的距离关系。
if len(styles) < 2:
inject_index = self.n_latent
if styles[0].ndim < 3:
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
else:
latent = styles[0]
else:
if inject_index is None:
inject_index = random.randint(1, self.n_latent - 1)
latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
latent2 = styles[1].unsqueeze(1).repeat(1, self.n_latent - inject_index, 1)
latent = torch.cat([latent, latent2], 1)
以上主要是样式混合时的代码函数。当len(styles) = 2,根据inject_index数混合合并latent。
out = self.input(latent) #产生batch*512*4*4随机矩阵
out = self.conv1(out, latent[:, 0], noise=noise[0]) #latent[:,i]对应每一层style-vector
skip = self.to_rgb1(out, latent[:, 1])
i = 1
for conv1, conv2, noise1, noise2, to_rgb in zip(
self.convs[::2], self.convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
i += 2
image = skip
以上就是Synthesis网络的主要推理流程了。