stylegan2代码复现,适应多分类的图像生成
本文的工作
stylegan2的实现大多数已经被实现,但是大多数简单的实现不关注多类别的图像生成,本文的工作在借鉴stylegan2的pytorch实现基础上,进行修改,尽量不使用多余的代码搭建生成器和辨别器的模型
本文实现的代码链接
stylegan2
StyleGAN2 是由 NVIDIA 团队在 2020 年提出的生成对抗网络(GAN)模型,是 StyleGAN 的改进版本。它在图像生成质量、训练稳定性和生成细节控制方面均有显著提升,广泛应用于图像合成、风格迁移、图像编辑等领域。
如果对你有所帮助,请你多多支持和关注
stylegan2的工作过程:
映射网络(Mapping Network)
将输入噪声 z 转换为中间风格向量 w,解耦隐空间的语义信息。
生成器(Generator)
由多个残差块组成,每个残差块接收 w 的不同子向量,控制不同层级的特征。
权重解调(Weight Demodulation)
替代 AdaIN,通过动态调整卷积权重,避免伪影。
stylegan2的优势
生成图像质量极高,细节丰富。
隐空间解耦性好,支持细粒度编辑。
训练稳定性优于多数 GAN 变体
话不多说,接下来我们开始进入正文:
Stylegan2代码构建
Mapping的网络代码
import torch
import torch.nn as nn
class MappingNetwork(torch.nn.Module):
def __init__(self,
z_dim, # 输入潜在向量Z的维度,0表示不使用Z
c_dim, # 条件标签C的维度,0表示无标签
w_dim, # 中间潜在向量W的维度
num_ws, # 输出的中间潜在向量数量,None表示不广播
num_layers=8, # 网络层数
w_avg_beta=0.998 # W平均值的EMA衰减系数
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.num_ws = num_ws
self.num_layers = num_layers
embed_features = w_dim
self.w_avg_beta = w_avg_beta
# 条件标签C的嵌入层(若无条件则跳过)
if c_dim == 0:
embed_features = 0
layer_features = w_dim
# 定义各层输入输出维度:[输入维度, 中间层维度..., 输出维度]
features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]
# 条件嵌入网络(当c_dim>0时生效)
if c_dim > 0:
self.embed = nn.Sequential(
EqualizedLinear(c_dim, embed_features), # 均衡化线性层
nn.Linear(embed_features, embed_features)
)
# 动态构建全连接层(带LeakyReLU激活)
for idx in range(num_layers):
in_features = features_list[idx]
out_features = features_list[idx + 1]
layer = nn.Sequential(
EqualizedLinear(in_features, out_features),
nn.LeakyReLU()
)
setattr(self, f'fc{idx}', layer) # 设置层属性,如fc0, fc1,...
# 初始化W的指数移动平均(EMA)缓存
if num_ws is not None and w_avg_beta is not None:
self.register_buffer('w_avg', torch.zeros([w_dim]))
def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
x = None
# 处理输入Z:归一化二阶矩(标准化)
if self.z_dim > 0:
x = normalize_2nd_moment(z.to(torch.float32))
# 处理条件C:嵌入并归一化,与Z拼接
if self.c_dim > 0:
y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
x = torch.cat([x, y], dim=1) if x is not None else y
# 前向传播:通过所有全连接层
for idx in range(self.num_layers):
layer = getattr(self, f'fc{idx}')
x = layer(x)
# 更新W的EMA(指数移动平均)
with torch.cuda.amp.autocast(enabled=False):
if update_emas and self.w_avg_beta is not None:
self.w_avg.copy_(
x.detach().mean(dim=0).lerp(self.w_avg.to(torch.float16), self.w_avg_beta)
)
# 广播W向量(生成多个中间潜在向量)
if self.num_ws is not None:
x = x.unsqueeze(1).repeat([1, self.num_ws, 1])
# 截断技巧:控制生成结果的多样性与质量平衡
if truncation_psi != 1:
if self.num_ws is None or truncation_cutoff is None:
x = self.w_avg.lerp(x.to(torch.float32), truncation_psi)
else:
x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
return x
上述的映射网络便是将输入潜在向量 Z 和条件标签 C转换为中间潜在向量 W,通过该向量的输入进行每一个条件的映射
生成器定义
import torch
import torch.nn as nn
class Generator(torch.nn.Module):
def __init__(self,
z_dim, # 输入噪声向量Z的维度
c_dim, # 条件标签C的维度(如类别、属性)
w_dim, # 中间潜在向量W的维度
img_resolution, # 生成图像的分辨率(如256x256)
img_channels # 图像通道数(如3表示RGB)
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_channels = img_channels
# 核心组件1:合成网络(生成图像)
self.synthesis = SynthesisNetwork(
w_dim=w_dim,
img_resolution=img_resolution,
img_channels=img_channels
)
# 核心组件2:映射网络(生成中间潜在向量W)
self.num_ws = self.synthesis.num_ws # 从合成网络获取W的数量
self.mapping = MappingNetwork(
z_dim=z_dim,
c_dim=c_dim,
w_dim=w_dim,
num_ws=self.num_ws
)
# 缓存中间潜在向量W(用于调试或可视化)
self.ws = None
def getws(self):
"""获取当前生成的中间潜在向量W"""
return self.ws if self.ws is not None else None
def forward(self, z, c):
# 步骤1:通过映射网络生成中间潜在向量W
ws = self.mapping(z, c) # z: [batch, z_dim], c: [batch, c_dim]
self.ws = ws.detach() # 缓存W(不参与梯度计算)
# 步骤2:通过合成网络生成图像
img = self.synthesis(ws) # ws: [batch, num_ws, w_dim]
return img # img: [batch, channels, H, W]
此 Generator 是 StyleGAN2 的核心组件,其设计包含以下关键特征:
解耦潜在空间:通过 W 将噪声 Z 的语义信息分层解耦。
自适应风格控制:每个合成网络层接收不同的 W 子向量,独立控制不同层级的特征。
高分辨率生成:合成网络通常包含上采样块和残差连接,支持生成高清图像。
鉴别器定义
import torch
import numpy as np
class Discriminator(torch.nn.Module):
def __init__(self,
c_dim, # 条件标签C的维度(如类别标签)
img_resolution, # 输入图像分辨率(如256)
img_channels, # 输入图像通道数(如3表示RGB)
channel_base=32768, # 通道数基数(控制各层通道数)
channel_max=512, # 单层最大通道数限制
conv_clamp=256 # 卷积输出值截断范围(防止梯度爆炸)
):
super().__init__()
self.c_dim = c_dim
self.img_resolution = img_resolution
self.img_resolution_log2 = int(np.log2(img_resolution)) # 分辨率对数(如256→8)
self.img_channels = img_channels
self.block_resolutions = [2**i for i in range(self.img_resolution_log2, 2, -1)] # 处理的分辨率层级(如256,128,...,8)
# 计算各分辨率对应的通道数(随分辨率降低而增加)
channels_dict = {
res: min(channel_base // res, channel_max)
for res in self.block_resolutions + [4] # 包含最终4x4层
}
cmap_dim = channels_dict[4] # 条件映射维度
if c_dim == 0:
cmap_dim = 0 # 无条件时禁用条件映射
# 构建多分辨率判别块
cur_layer_idx = 0
for res in self.block_resolutions:
in_channels = channels_dict[res] if res < img_resolution else 0 # 输入通道数
tmp_channels = channels_dict[res] # 中间通道数
out_channels = channels_dict[res//2] # 输出通道数(下一层输入)
# 创建判别块(如b256, b128,...)
block = DiscriminatorBlock(
in_channels, tmp_channels, out_channels,
resolution=res, first_layer_idx=cur_layer_idx,
img_channels=img_channels, conv_clamp=conv_clamp
)
setattr(self, f'b{res}', block)
cur_layer_idx += block.num_layers # 更新全局层索引
# 条件映射网络(若有条件标签)
if c_dim > 0:
self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None)
# 最终判别尾部网络(4x4分辨率)
self.b4 = DiscriminatorEpilogue(
channels_dict[4], cmap_dim=cmap_dim, resolution=4,
img_channels=img_channels, conv_clamp=conv_clamp
)
def forward(self, img, c):
x = None # 中间特征(初始为空)
# 逐级处理图像(从高分辨率到低分辨率)
for res in self.block_resolutions:
block = getattr(self, f'b{res}')
x, img = block(x, img) # 返回当前层特征和下采样后的
此判别器设计与生成器对称,是StyleGAN2完整对抗训练框架的核心组件之一。
无条件GAN训练:当 c_dim=0 时,判别器仅分析图像内容。
条件GAN训练:通过 C 实现类别控制(如生成指定类别的图像)。
图像质量评估:判别器输出的得分可用于量化生成图像的逼真度。