Reference code: https://github.com/christiancosgrove/pytorch-spectral-normalization-gan
Reference code: https://github.com/heykeetae/Self-Attention-GAN
Reference code: https://github.com/taki0112/Self-Attention-GAN-Tensorflow
Spectral normalization constrains the weight matrix W by dividing it by its spectral norm sigma(W), i.e. its largest singular value: W_SN = W / sigma(W). This caps each layer's Lipschitz constant at roughly 1, which is what stabilizes discriminator training.
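To make sigma(W) concrete, here is a minimal NumPy sketch (my own illustration, not from the repos above; power_iteration is a hypothetical helper name). It shows that a few power-iteration steps recover the largest singular value, and that W / sigma(W) then has spectral norm 1:

import numpy as np

def power_iteration(w, steps=50):
    # Estimate sigma(W), the largest singular value, via power iteration.
    u = np.random.randn(1, w.shape[1])
    for _ in range(steps):
        v = u @ w.T
        v /= np.linalg.norm(v)
        u = v @ w
        u /= np.linalg.norm(u)
    return (v @ w @ u.T).item()

w = np.random.randn(48, 64)
sigma = power_iteration(w)
print(sigma, np.linalg.svd(w, compute_uv=False)[0])   # the two values agree
print(np.linalg.svd(w / sigma, compute_uv=False)[0])  # ~1.0 after normalization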
The spectral normalization code below can be copied in as-is; how to call it is shown in the next code block:
import tensorflow as tf

weight_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
weight_regularizer = None

def spectral_norm(w, iteration=1):
    w_shape = w.shape.as_list()
    # Flatten the kernel into a 2-D matrix, e.g. (48, 64), (1024, 128), (2048, 256).
    w = tf.reshape(w, [-1, w_shape[-1]])

    # Persistent estimate of the top right-singular vector, refined by power iteration.
    u = tf.get_variable("u", [1, w_shape[-1]],
                        initializer=tf.truncated_normal_initializer(),
                        trainable=False)

    u_hat = u
    v_hat = None
    for i in range(iteration):
        # Power iteration; usually iteration = 1 is enough.
        v_ = tf.matmul(u_hat, tf.transpose(w))  # shape (1, rows), e.g. (1, 48)
        v_hat = l2_norm(v_)
        u_ = tf.matmul(v_hat, w)                # shape (1, cols), e.g. (1, 64)
        u_hat = l2_norm(u_)

    # sigma approximates the largest singular value of w; shape (1, 1).
    sigma = tf.matmul(tf.matmul(v_hat, w), tf.transpose(u_hat))
    w_norm = w / sigma

    # Persist the refined u, then restore the original kernel shape.
    with tf.control_dependencies([u.assign(u_hat)]):
        w_norm = tf.reshape(w_norm, w_shape)

    return w_norm

def l2_norm(v, eps=1e-12):
    return v / (tf.reduce_sum(v ** 2) ** 0.5 + eps)
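One caveat: spectral_norm creates a variable literally named "u" through tf.get_variable, so calling it on a second weight inside the same variable scope fails with "Variable u already exists". A minimal guard, assuming you can pick a unique scope name per weight (spectral_norm_scoped is my own wrapper name):

def spectral_norm_scoped(w, name, iteration=1):
    # Give each normalized kernel its own scope so the "u" variables don't collide.
    with tf.variable_scope(name + '_sn'):
        return spectral_norm(w, iteration=iteration)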
Calling spectral normalization: fetch your weight tensor w from the graph and wrap it. Note that self.core is the layer's output tensor, while spectral_norm must act on the kernel; with tf.layers the kernel lives under the layer's name scope as '.../transconv2d/kernel:0' (this name derivation assumes the default tf.layers naming):

kernel_name = self.core.name.rsplit('/', 1)[0] + '/kernel:0'
w = tf.get_default_graph().get_tensor_by_name(kernel_name)
w = spectral_norm(w)
This is my scope; the enclosing name_scope is set earlier:

elif self.layer_type == 'transconv2d':
    self.core = tf.layers.conv2d_transpose(
        self.conditioned, filters, kernel_size, strides, padding,
        kernel_initializer=kernel_initializer, name='transconv2d')
    # Fetch the layer's kernel (not its output) and normalize it.
    kernel_name = self.core.name.rsplit('/', 1)[0] + '/kernel:0'
    w = tf.get_default_graph().get_tensor_by_name(kernel_name)
    w = spectral_norm(w)
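A caution on this post-hoc approach: computing spectral_norm(w) on a tensor fetched from the graph does not by itself rewire the layer, because tf.layers.conv2d_transpose has already consumed the raw kernel. The taki0112 reference repo instead builds the op manually and feeds the normalized kernel straight into tf.nn.conv2d_transpose; a condensed sketch of that pattern (output_shape here assumes 'SAME' padding, NHWC layout, and a stride that scales spatial size accordingly):

def sn_deconv(x, channels, kernel=4, stride=2, scope='deconv_0'):
    # Transposed convolution whose kernel is spectrally normalized before use.
    with tf.variable_scope(scope):
        x_shape = x.get_shape().as_list()
        output_shape = [tf.shape(x)[0],
                        x_shape[1] * stride, x_shape[2] * stride, channels]
        w = tf.get_variable("kernel",
                            shape=[kernel, kernel, channels, x_shape[-1]],
                            initializer=weight_init,
                            regularizer=weight_regularizer)
        # The op consumes the normalized kernel, so the constraint is in the graph.
        return tf.nn.conv2d_transpose(x, filter=spectral_norm(w),
                                      output_shape=output_shape,
                                      strides=[1, stride, stride, 1],
                                      padding='SAME')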