PassGAN子函数分析(1)
- lib.ops.linear.Linear
- tf.reshape
- tf.transpose
lib.ops.linear.Linear()
该函数是原代码里面自带的,具体内容如下:
import tflib as lib
import numpy as np
import tensorflow as tf
_default_weightnorm = False
def enable_default_weightnorm():
global _default_weightnorm
_default_weightnorm = True
def disable_default_weightnorm():
global _default_weightnorm
_default_weightnorm = False
_weights_stdev = None
def set_weights_stdev(weights_stdev):
global _weights_stdev
_weights_stdev = weights_stdev
def unset_weights_stdev():
global _weights_stdev
_weights_stdev = None
#以上主要是定义了两个关键得变量,一个是weights_stdev,另一个是_default_weightnorm
以上主要是定义了两个关键得变量,一个是weights_stdev,另一个是_default_weightnorm,这两个变量在后面的函数中具有非常重要的意义。
下面是函数Linear函数的定义:
def Linear(
name,
input_dim,
output_dim,
inputs,
biases=True,
initialization=None,
weightnorm=None,
gain=1.
):
"""
initialization: None, `lecun`, 'glorot', `he`, 'glorot_he', `orthogonal`, `("uniform", range)`
"""
with tf.name_scope(name) as scope:
def uniform(stdev, size):
if _weights_stdev is not None:
stdev = _weights_stdev
return np.random.uniform(
low=-stdev * np.sqrt(3),
high=stdev * np.sqrt(3),
size=size
).astype('float32')
# 从一个均匀分布[low,high)中随机采样,
# 注意定义域是左闭右开,即包含low,不包含high,形状大小按照size来
if initialization == 'lecun':# and input_dim != output_dim):
# disabling orth. init for now because it's too slow
weight_values = uniform(
np.sqrt(1./input_dim),
(input_dim, output_dim)
)
elif initialization == 'glorot' or (initialization == None):
weight_values = uniform(
np.sqrt(2./(input_dim+output_dim)),
(input_dim, output_dim)
)
elif initialization == 'he':
weight_values = uniform(
np.sqrt(2./input_dim),
(input_dim, output_dim)
)
elif initialization == 'glorot_he'