Style Transfer 0-09: stylegan - A Thorough Source Code Walkthrough (5) - The Discriminator Network in Detail

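The link below collects all of my notes on stylegan. If you spot a mistake, please point it out and I will fix it right away. If you are interested, add me on WeChat (17575010159) to discuss the technical details. If this helped you, please remember to give it a like, since that is my greatest encouragement.
Style Transfer 0-00: stylegan - Table of contents (the most complete ever): https://blog.csdn.net/weixin_43013761/article/details/100895333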

Annotated source code

Before the annotated source, here is the structure of the Discriminator. If you have already run the program, you will have seen this printout many times:

D                     Params    OutputShape          WeightShape     
---                   ---       ---                  ---             
images_in             -         (?, 3, 1024, 1024)   -               
labels_in             -         (?, 0)               -               
lod                   -         ()                   -               
FromRGB_lod0          64        (?, 16, 1024, 1024)  (1, 1, 3, 16)   
1024x1024/Conv0       2320      (?, 16, 1024, 1024)  (3, 3, 16, 16)  
1024x1024/Conv1_down  4640      (?, 32, 512, 512)    (3, 3, 16, 32)  
Downscale2D           -         (?, 3, 512, 512)     -               
FromRGB_lod1          128       (?, 32, 512, 512)    (1, 1, 3, 32)   
Grow_lod0             -         (?, 32, 512, 512)    -               
512x512/Conv0         9248      (?, 32, 512, 512)    (3, 3, 32, 32)  
512x512/Conv1_down    18496     (?, 64, 256, 256)    (3, 3, 32, 64)  
Downscale2D_1         -         (?, 3, 256, 256)     -               
FromRGB_lod2          256       (?, 64, 256, 256)    (1, 1, 3, 64)   
Grow_lod1             -         (?, 64, 256, 256)    -               
256x256/Conv0         36928     (?, 64, 256, 256)    (3, 3, 64, 64)  
256x256/Conv1_down    73856     (?, 128, 128, 128)   (3, 3, 64, 128) 
Downscale2D_2         -         (?, 3, 128, 128)     -               
FromRGB_lod3          512       (?, 128, 128, 128)   (1, 1, 3, 128)  
Grow_lod2             -         (?, 128, 128, 128)   -               
128x128/Conv0         147584    (?, 128, 128, 128)   (3, 3, 128, 128)
128x128/Conv1_down    295168    (?, 256, 64, 64)     (3, 3, 128, 256)
Downscale2D_3         -         (?, 3, 64, 64)       -               
FromRGB_lod4          1024      (?, 256, 64, 64)     (1, 1, 3, 256)  
Grow_lod3             -         (?, 256, 64, 64)     -               
64x64/Conv0           590080    (?, 256, 64, 64)     (3, 3, 256, 256)
64x64/Conv1_down      1180160   (?, 512, 32, 32)     (3, 3, 256, 512)
Downscale2D_4         -         (?, 3, 32, 32)       -               
FromRGB_lod5          2048      (?, 512, 32, 32)     (1, 1, 3, 512)  
Grow_lod4             -         (?, 512, 32, 32)     -               
32x32/Conv0           2359808   (?, 512, 32, 32)     (3, 3, 512, 512)
32x32/Conv1_down      2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
Downscale2D_5         -         (?, 3, 16, 16)       -               
FromRGB_lod6          2048      (?, 512, 16, 16)     (1, 1, 3, 512)  
Grow_lod5             -         (?, 512, 16, 16)     -               
16x16/Conv0           2359808   (?, 512, 16, 16)     (3, 3, 512, 512)
16x16/Conv1_down      2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
Downscale2D_6         -         (?, 3, 8, 8)         -               
FromRGB_lod7          2048      (?, 512, 8, 8)       (1, 1, 3, 512)  
Grow_lod6             -         (?, 512, 8, 8)       -               
8x8/Conv0             2359808   (?, 512, 8, 8)       (3, 3, 512, 512)
8x8/Conv1_down        2359808   (?, 512, 4, 4)       (3, 3, 512, 512)
Downscale2D_7         -         (?, 3, 4, 4)         -               
FromRGB_lod8          2048      (?, 512, 4, 4)       (1, 1, 3, 512)  
Grow_lod7             -         (?, 512, 4, 4)       -               
4x4/MinibatchStddev   -         (?, 513, 4, 4)       -               
4x4/Conv              2364416   (?, 512, 4, 4)       (3, 3, 513, 512)
4x4/Dense0            4194816   (?, 512)             (8192, 512)     
4x4/Dense1            513       (?, 1)               (512, 1)        
scores_out            -         (?, 1)               -               
---                   ---       ---                  ---             
Total                 23087249   
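
The Params column can be checked by hand: a convolution has kernel*kernel*in*out weights plus one bias per output channel, and a dense layer has in*out weights plus out biases. The sketch below is only an illustration, not part of the stylegan code. It reuses the channel formula that D_basic defines as nf(), while conv_params, dense_params and count_d_params are hypothetical helper names introduced here, and it assumes the layer layout shown in the table (one FromRGB per resolution, two convolutions per block, and the 4x4 tail).

```python
def nf(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512):
    """Feature-map count for a stage, same formula as nf() in D_basic."""
    return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

def conv_params(kernel, c_in, c_out):
    """Weights plus biases of a kernel x kernel convolution (hypothetical helper)."""
    return kernel * kernel * c_in * c_out + c_out

def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected layer (hypothetical helper)."""
    return n_in * n_out + n_out

def count_d_params(resolution_log2=10, num_channels=3,
                   mbstd_num_features=1, label_size=0):
    """Sum the parameters of every layer listed in the table above."""
    total = 0
    # One 1x1 FromRGB layer per resolution, 4x4 up to 1024x1024.
    for res in range(2, resolution_log2 + 1):
        total += conv_params(1, num_channels, nf(res - 1))
    # Conv0 and Conv1_down for every block from 1024x1024 down to 8x8.
    for res in range(resolution_log2, 2, -1):
        total += conv_params(3, nf(res - 1), nf(res - 1))
        total += conv_params(3, nf(res - 1), nf(res - 2))
    # 4x4 tail: conv on 512 + 1 stddev channels, then Dense0 and Dense1.
    total += conv_params(3, nf(1) + mbstd_num_features, nf(1))
    total += dense_params(nf(1) * 4 * 4, nf(0))
    total += dense_params(nf(0), max(label_size, 1))
    return total

print(conv_params(1, 3, nf(9)))  # FromRGB_lod0 -> 64
print(count_d_params())          # -> 23087249
```

Both numbers agree with the printout: 64 for FromRGB_lod0 and 23087249 for the total.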

Below is the annotated source:

#----------------------------------------------------------------------------
# Discriminator used in the StyleGAN paper.

def D_basic(
    images_in,                          # First input: Images [minibatch, channel, height, width].
    labels_in,                          # Second input: Labels [minibatch, label_size].
    num_channels        = 1,            # Number of input color channels. Overridden based on dataset.
    resolution          = 32,           # Input resolution. Overridden based on dataset.
    label_size          = 0,            # Dimensionality of the labels, 0 if no labels. Overridden based on dataset.
    fmap_base           = 8192,         # Overall multiplier for the number of feature maps.
    fmap_decay          = 1.0,          # log2 feature map reduction when doubling the resolution.
    fmap_max            = 512,          # Maximum number of feature maps in any layer.
    nonlinearity        = 'lrelu',      # Activation function: 'relu', 'lrelu',
    use_wscale          = True,         # Enable equalized learning rate?
    mbstd_group_size    = 4,            # Group size for the minibatch standard deviation layer, 0 = disable.
    mbstd_num_features  = 1,            # Number of features for the minibatch standard deviation layer.
    dtype               = 'float32',    # Data type to use for activations and outputs.
    fused_scale         = 'auto',       # True = fused convolution + scaling, False = separate ops, 'auto' = decide automatically.
    blur_filter         = [1,2,1],      # Low-pass filter to apply when resampling activations. None = no filtering.
    structure           = 'auto',       # 'fixed' = no progressive growing, 'linear' = human-readable, 'recursive' = efficient, 'auto' = select automatically.
    is_template_graph   = False,        # True = template graph constructed by the Network class, False = actual evaluation.
    **_kwargs):                         # Ignore unrecognized keyword args.

    # In our network the input is (?, 3, 1024, 1024), so resolution_log2 is 10.
    resolution_log2 = int(np.log2(resolution))
    assert resolution == 2**resolution_log2 and resolution >= 4

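    # nf(stage) returns the number of feature maps (channels) for a given stage:
    # fmap_base / 2^(stage * fmap_decay), capped at fmap_max. A different stage gives a different channel count.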
    def nf(stage): return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

    # Low-pass blur, applied before downsampling (skipped when blur_filter is None).
    def blur(x): return blur2d(x, blur_filter) if blur_filter else x

    if structure == 'auto': structure = 'linear' if is_template_graph else 'recursive'
    act, gain = {'relu': (tf.nn.relu, np.sqrt(2)), 'lrelu': (leaky_relu, np.sqrt(2))}[nonlinearity]

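    # Input images: at full resolution these are 1024x1024, either real images or images produced by the generator.
    images_in.set_shape([None, num_channels, resolution, resolution])
    # Labels (empty when label_size is 0).
    labels_in.set_shape([None, label_size])

    # Cast the inputs to the working dtype (float32 by default).
    images_in = tf.cast(images_in, dtype)
    labels_in = tf.cast(labels_in, dtype)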

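    # Read the current lod (level of detail). lod = 0 means full resolution and every +1 halves it,
    # so the effective resolution is resolution / 2^lod. Training starts at low resolution: the image
    # tensor is still 1024x1024 pixels, but after smoothing you can hardly tell what it shows.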
    lod_in = tf.cast(tf.get_variable('lod', initializer=np.float32(0.0), trainable=False), dtype)
    scores_out = None

    # Building blocks.
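    # A 1x1 convolution plus a bias, followed by the activation. res selects the resolution, but it is
    # log2 of the resolution (res = 10 means 1024x1024), so do not confuse it with the real resolution.
    # The input is an RGB image; the output is a feature map with nf(res-1) channels.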
    def fromrgb(x, res): # res = 2..resolution_log2
        with tf.variable_scope('FromRGB_lod%d' % (resolution_log2 - res)):
            return act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=1, gain=gain, use_wscale=use_wscale)))

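    # Two cases, selected by res: at 8x8 and above, two convolutions where the second one downsamples;
    # at 4x4, the minibatch stddev layer, one convolution and two dense layers.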
    def block(x, res): # res = 2..resolution_log2
        with tf.variable_scope('%dx%d' % (2**res, 2**res)):
            if res >= 3: # 8x8 and up
                with tf.variable_scope('Conv0'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Conv1_down'):
                    x = act(apply_bias(conv2d_downscale2d(blur(x), fmaps=nf(res-2), kernel=3, gain=gain, use_wscale=use_wscale, fused_scale=fused_scale)))
            else: # 4x4
                if mbstd_group_size > 1:
                    x = minibatch_stddev_layer(x, mbstd_group_size, mbstd_num_features)
                with tf.variable_scope('Conv'):
                    x = act(apply_bias(conv2d(x, fmaps=nf(res-1), kernel=3, gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Dense0'):
                    x = act(apply_bias(dense(x, fmaps=nf(res-2), gain=gain, use_wscale=use_wscale)))
                with tf.variable_scope('Dense1'):
                    x = apply_bias(dense(x, fmaps=max(label_size, 1), gain=1, use_wscale=use_wscale))
            return x

    # Fixed structure: simple and efficient, but does not support progressive growing.
    # The straightforward way: build the network directly at a fixed resolution, with no progressive growing.
    if structure == 'fixed':
        x = fromrgb(images_in, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            x = block(x, res)
        scores_out = block(x, 2)

    # Linear structure: simple but inefficient.
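    # Start at the highest resolution and downsample step by step. At every step the downscaled RGB image
    # goes through its own fromrgb layer, and lerp_clip blends that path with the block output according to lod_in.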
    if structure == 'linear':
        img = images_in
        x = fromrgb(img, resolution_log2)
        for res in range(resolution_log2, 2, -1):
            lod = resolution_log2 - res
            x = block(x, res)
            img = downscale2d(img)
            y = fromrgb(img, res - 1)
            with tf.variable_scope('Grow_lod%d' % lod):
                x = tflib.lerp_clip(x, y, lod_in - lod)
        scores_out = block(x, 2)

    # Recursive structure: complex but efficient.
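    # Builds the same network with nested tf.cond calls, so only the layers needed for the current lod_in
    # are evaluated: grow(res, lod) recurses to the higher resolution only when lod_in < lod, and blends in
    # the lower-resolution fromrgb path only when lod_in > lod.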
    if structure == 'recursive':
        def cset(cur_lambda, new_cond, new_lambda):
            return lambda: tf.cond(new_cond, new_lambda, cur_lambda)
        def grow(res, lod):
            x = lambda: fromrgb(downscale2d(images_in, 2**lod), res)
            if lod > 0: x = cset(x, (lod_in < lod), lambda: grow(res + 1, lod - 1))
            x = block(x(), res); y = lambda: x
            if res > 2: y = cset(y, (lod_in > lod), lambda: tflib.lerp(x, fromrgb(downscale2d(images_in, 2**(lod+1)), res - 1), lod_in - lod))
            return y()
        scores_out = grow(2, resolution_log2 - 2)

    # Label conditioning from "Which Training Methods for GANs do actually Converge?"
    # label_size is 0 here (the dataset has no labels), so this branch is not executed.
    if label_size:
        with tf.variable_scope('LabelSwitch'):
            scores_out = tf.reduce_sum(scores_out * labels_in, axis=1, keepdims=True)
    # scores_out has shape (?, 1). It is the raw realness score of each image (Dense1 has no activation),
    # not a probability.
    assert scores_out.dtype == tf.as_dtype(dtype)
    scores_out = tf.identity(scores_out, name='scores_out')
    return scores_out
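
The 'linear' branch above relies on tflib.lerp_clip to fade between the output of block(x, res) and the fromrgb shortcut computed from the downscaled image. As a rough NumPy illustration of that blend, assuming lerp_clip computes a + (b - a) * clip(t, 0, 1), the sketch below uses made-up constant arrays in place of the real feature maps:

```python
import numpy as np

def lerp_clip(a, b, t):
    """Linear interpolation from a to b with t clipped to [0, 1] (NumPy stand-in)."""
    return a + (b - a) * np.clip(t, 0.0, 1.0)

# Stand-ins for the two tensors blended inside the Grow_lod0 scope.
x = np.full((2,), 10.0)  # pretend output of block(x, res) at full resolution
y = np.full((2,), 20.0)  # pretend output of fromrgb(downscale2d(img), res - 1)

lod = 0
for lod_in in (0.0, 0.25, 1.0, 3.0):
    blended = lerp_clip(x, y, lod_in - lod)
    print(lod_in, blended)
```

With lod_in = 0 the result is entirely x, so the full-resolution block is used. With lod_in at 1 or above the result is entirely y, which means the 1024x1024 block no longer influences the score. Values in between give a smooth fade.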

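The structure really is simple. In one sentence: an image goes in, passes through a series of convolutions, activations and fully connected layers, and a single value comes out. That value is the discriminator's score for how real the image looks. It is the raw, unbounded output of Dense1 rather than a probability.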
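
One layer in that pipeline is less obvious than the rest: 4x4/MinibatchStddev, which turns (?, 512, 4, 4) into (?, 513, 4, 4) in the printout. The NumPy sketch below shows one way such a layer can be computed. It is an approximation written for this article, not the stylegan implementation, and it assumes the batch size is divisible by the group size.

```python
import numpy as np

def minibatch_stddev(x, group_size=4, num_new_features=1):
    """Append per-group standard deviation statistics as extra channels.

    x has shape [N, C, H, W]; the result has shape [N, C + num_new_features, H, W].
    """
    n, c, h, w = x.shape
    g = min(group_size, n)   # samples per group
    f = num_new_features
    # Split the batch into groups: [G, M, F, C/F, H, W], where M = N / G.
    y = x.reshape(g, -1, f, c // f, h, w).astype(np.float32)
    y = y - y.mean(axis=0, keepdims=True)          # center within each group
    y = np.sqrt((y ** 2).mean(axis=0) + 1e-8)      # stddev over the group: [M, F, C/F, H, W]
    y = y.mean(axis=(2, 3, 4), keepdims=True)      # average over channels and pixels
    y = y.mean(axis=2)                             # [M, F, 1, 1]
    y = np.tile(y, (g, 1, h, w)).astype(x.dtype)   # broadcast back to [N, F, H, W]
    return np.concatenate([x, y], axis=1)

features = np.random.RandomState(0).randn(8, 512, 4, 4).astype(np.float32)
out = minibatch_stddev(features)
print(out.shape)  # (8, 513, 4, 4)
```

The extra channel carries a statistic of how much the samples in a group differ from each other, which is why 4x4/Conv has a weight shape of (3, 3, 513, 512).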

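With that, our analysis of the whole network is complete. Next comes the core of the core: the loss functions. If my posts have been helpful to you, I hope you will give them a like. Thank you all for following along.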
