stylegan2代码复现,适应多分类的图像生成

本文的工作

stylegan2的实现大多数已经被实现,但是大多数简单的实现不关注多类别的图像生成,本文的工作在借鉴stylegan2的pytorch实现基础上,进行修改,尽量不使用多余的代码搭建生成器和辨别器的模型
本文实现的代码链接

stylegan2

StyleGAN2 是由 NVIDIA 团队在 2020 年提出的生成对抗网络(GAN)模型,是 StyleGAN 的改进版本。它在图像生成质量、训练稳定性和生成细节控制方面均有显著提升,广泛应用于图像合成、风格迁移、图像编辑等领域。
如果对你有所帮助,请你多多支持和关注

stylegan2的工作过程:

工作过程
在这里插入图片描述
在这里插入图片描述

映射网络(Mapping Network)
将输入噪声 z 转换为中间风格向量 w,解耦隐空间的语义信息。
生成器(Generator)
由多个残差块组成,每个残差块接收 w 的不同子向量,控制不同层级的特征。
权重解调(Weight Demodulation)
替代 AdaIN,通过动态调整卷积权重,避免伪影。

stylegan2的优势

生成图像质量极高,细节丰富。

隐空间解耦性好,支持细粒度编辑。

训练稳定性优于多数 GAN 变体

话不多说,接下来我们开始进入正文:

Stylegan2代码构建

Mapping的网络代码

import torch
import torch.nn as nn

class MappingNetwork(torch.nn.Module):
    def __init__(self,
                 z_dim,       # 输入潜在向量Z的维度,0表示不使用Z
                 c_dim,       # 条件标签C的维度,0表示无标签
                 w_dim,       # 中间潜在向量W的维度
                 num_ws,      # 输出的中间潜在向量数量,None表示不广播
                 num_layers=8, # 网络层数
                 w_avg_beta=0.998  # W平均值的EMA衰减系数
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.num_ws = num_ws
        self.num_layers = num_layers
        embed_features = w_dim
        self.w_avg_beta = w_avg_beta

        # 条件标签C的嵌入层(若无条件则跳过)
        if c_dim == 0:
            embed_features = 0
        layer_features = w_dim
        # 定义各层输入输出维度:[输入维度, 中间层维度..., 输出维度]
        features_list = [z_dim + embed_features] + [layer_features] * (num_layers - 1) + [w_dim]

        # 条件嵌入网络(当c_dim>0时生效)
        if c_dim > 0:
            self.embed = nn.Sequential(
                EqualizedLinear(c_dim, embed_features),  # 均衡化线性层
                nn.Linear(embed_features, embed_features)
            )
        
        # 动态构建全连接层(带LeakyReLU激活)
        for idx in range(num_layers):
            in_features = features_list[idx]
            out_features = features_list[idx + 1]
            layer = nn.Sequential(
                EqualizedLinear(in_features, out_features),
                nn.LeakyReLU()
            )
            setattr(self, f'fc{idx}', layer)  # 设置层属性,如fc0, fc1,...

        # 初始化W的指数移动平均(EMA)缓存
        if num_ws is not None and w_avg_beta is not None:
            self.register_buffer('w_avg', torch.zeros([w_dim]))

    def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        x = None
        # 处理输入Z:归一化二阶矩(标准化)
        if self.z_dim > 0:
            x = normalize_2nd_moment(z.to(torch.float32))
        # 处理条件C:嵌入并归一化,与Z拼接
        if self.c_dim > 0:
            y = normalize_2nd_moment(self.embed(c.to(torch.float32)))
            x = torch.cat([x, y], dim=1) if x is not None else y

        # 前向传播:通过所有全连接层
        for idx in range(self.num_layers):
            layer = getattr(self, f'fc{idx}')
            x = layer(x)

        # 更新W的EMA(指数移动平均)
        with torch.cuda.amp.autocast(enabled=False):
            if update_emas and self.w_avg_beta is not None:
                self.w_avg.copy_(
                    x.detach().mean(dim=0).lerp(self.w_avg.to(torch.float16), self.w_avg_beta)
                )

        # 广播W向量(生成多个中间潜在向量)
        if self.num_ws is not None:
            x = x.unsqueeze(1).repeat([1, self.num_ws, 1])

        # 截断技巧:控制生成结果的多样性与质量平衡
        if truncation_psi != 1:
            if self.num_ws is None or truncation_cutoff is None:
                x = self.w_avg.lerp(x.to(torch.float32), truncation_psi)
            else:
                x[:, :truncation_cutoff] = self.w_avg.lerp(x[:, :truncation_cutoff], truncation_psi)
        return x

上述的映射网络便是将输入潜在向量 Z 和条件标签 C转换为中间潜在向量 W,通过该向量的输入进行每一个条件的映射

生成器定义

import torch
import torch.nn as nn

class Generator(torch.nn.Module):
    def __init__(self,
                 z_dim,            # 输入噪声向量Z的维度
                 c_dim,            # 条件标签C的维度(如类别、属性)
                 w_dim,            # 中间潜在向量W的维度
                 img_resolution,   # 生成图像的分辨率(如256x256)
                 img_channels      # 图像通道数(如3表示RGB)
                 ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # 核心组件1:合成网络(生成图像)
        self.synthesis = SynthesisNetwork(
            w_dim=w_dim,
            img_resolution=img_resolution,
            img_channels=img_channels
        )
        # 核心组件2:映射网络(生成中间潜在向量W)
        self.num_ws = self.synthesis.num_ws  # 从合成网络获取W的数量
        self.mapping = MappingNetwork(
            z_dim=z_dim,
            c_dim=c_dim,
            w_dim=w_dim,
            num_ws=self.num_ws
        )
        # 缓存中间潜在向量W(用于调试或可视化)
        self.ws = None

    def getws(self):
        """获取当前生成的中间潜在向量W"""
        return self.ws if self.ws is not None else None

    def forward(self, z, c):
        # 步骤1:通过映射网络生成中间潜在向量W
        ws = self.mapping(z, c)  # z: [batch, z_dim], c: [batch, c_dim]
        self.ws = ws.detach()    # 缓存W(不参与梯度计算)
        
        # 步骤2:通过合成网络生成图像
        img = self.synthesis(ws) # ws: [batch, num_ws, w_dim]
        return img               # img: [batch, channels, H, W]

此 Generator 是 StyleGAN2 的核心组件,其设计包含以下关键特征:

解耦潜在空间:通过 W 将噪声 Z 的语义信息分层解耦。

自适应风格控制:每个合成网络层接收不同的 W 子向量,独立控制不同层级的特征。

高分辨率生成:合成网络通常包含上采样块和残差连接,支持生成高清图像。

鉴别器定义

import torch
import numpy as np

class Discriminator(torch.nn.Module):
    def __init__(self,
                 c_dim,           # 条件标签C的维度(如类别标签)
                 img_resolution,  # 输入图像分辨率(如256)
                 img_channels,    # 输入图像通道数(如3表示RGB)
                 channel_base=32768,  # 通道数基数(控制各层通道数)
                 channel_max=512,     # 单层最大通道数限制
                 conv_clamp=256       # 卷积输出值截断范围(防止梯度爆炸)
                 ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))  # 分辨率对数(如256→8)
        self.img_channels = img_channels
        self.block_resolutions = [2**i for i in range(self.img_resolution_log2, 2, -1)]  # 处理的分辨率层级(如256,128,...,8)

        # 计算各分辨率对应的通道数(随分辨率降低而增加)
        channels_dict = {
            res: min(channel_base // res, channel_max) 
            for res in self.block_resolutions + [4]  # 包含最终4x4层
        }
        cmap_dim = channels_dict[4]  # 条件映射维度
        if c_dim == 0:
            cmap_dim = 0  # 无条件时禁用条件映射

        # 构建多分辨率判别块
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0  # 输入通道数
            tmp_channels = channels_dict[res]    # 中间通道数
            out_channels = channels_dict[res//2] # 输出通道数(下一层输入)

            # 创建判别块(如b256, b128,...)
            block = DiscriminatorBlock(
                in_channels, tmp_channels, out_channels,
                resolution=res, first_layer_idx=cur_layer_idx,
                img_channels=img_channels, conv_clamp=conv_clamp
            )
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers  # 更新全局层索引

        # 条件映射网络(若有条件标签)
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None)

        # 最终判别尾部网络(4x4分辨率)
        self.b4 = DiscriminatorEpilogue(
            channels_dict[4], cmap_dim=cmap_dim, resolution=4,
            img_channels=img_channels, conv_clamp=conv_clamp
        )

    def forward(self, img, c):
        x = None  # 中间特征(初始为空)
        # 逐级处理图像(从高分辨率到低分辨率)
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img)  # 返回当前层特征和下采样后的

此判别器设计与生成器对称,是StyleGAN2完整对抗训练框架的核心组件之一。
无条件GAN训练:当 c_dim=0 时,判别器仅分析图像内容。
条件GAN训练:通过 C 实现类别控制(如生成指定类别的图像)。
图像质量评估:判别器输出的得分可用于量化生成图像的逼真度。

代码部分参考

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值