训练过程
在原始GAN理论中,应该在生成器之前对鉴别器进行训练。但在实践中,由于鉴别器能更快的训练,因此鉴别器的梯度将逐渐消失。有了Wasserstein损失函数后,可以在任何地方推导梯度,将不必担心评论家相较生成器过于强大。
因此,在WGAN中,对于生成器的每一个训练步骤,评论家都会接受五次训练。为了做到这一点,我们将评论家训练步骤写为一个单独的函数,然后可以循环多次:
for _ in range(self.n_critic):
real_images = next(data_generator)
critic_loss = self.train_critic(real_images, batch_size)
生成器的训练步骤:
self.critic = self.build_critic()
self.critic.trainable = False
self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss = self.wasserstein_loss, optimizer = RMSprop(3e-4))
self.critic.trainable = True
在前面的代码中,通过设置trainable = False
冻结了评论者层,并将其链接到生成器以创建一个新模型并进行编译。之后,我们可以将评论家设置为可训练,这不会影响我们已经编译的模型。
我们使用train_on_batch()
API执行单个训练步骤,该步骤将自动进行前向计算,损失计算,反向传播和权重更新:
g_loss = self.model.train_on_batch(g_input, real_labels)
下图显示了WGAN生成器体系结构:
下图显示了WGAN评论家体系结构:
尽管较原始GAN方面有所改进,但训练WGAN十分困难,并且所产生的图像质量并不比原始GAN更好。接下来,将实现WGAN的变体WGAN-GP,该变体训练速度更快,并产生更清晰的图像。
实现梯度惩罚(WGAN-GP)
正如WGAN作者所承认的那样,权重裁剪并不是实施Lipschitz约束的理想方法。其有两个缺点:网络容量使用不足和梯度爆炸/消失。当我们裁剪权重时,我们也限制了评论家的学习能力。权重裁剪迫使网络仅学习简单特征。因此,神经网络的容量变得未被充分利用。其次,裁剪值需要仔细调整。如果设置得太高,梯度会爆炸,从而违反了Lipschitz约束。如果设置得太低,则随着网络反向传播,梯度将消失。同样,权重裁剪会将梯度推到两个极限值,如下图所示:
因此,提出了梯度惩罚(GP)来代替权重裁剪以强制实施Lipschitz约束,如下所示:
G r a d i e n t p e n a l t y = λ E x ^ [ ( ∥ ∇ x ^ D ( x ^ ) ∥ 2 − 1 ) 2 ] Gradient\ penalty = \lambda E\hat x[(\lVert \nabla _{\hat x}D(\hat x) \rVert_2-1)^2] Gradient penalty=λEx[(∥∇xD(x^)∥2−1)2]
我们将查看方程式中的每个变量,并在代码中实现它们。
我们通常使用 x x x表示真实图像,但是现在方程式中有一个 x ^ \hat x x^。 x ^ \hat x x^是真实图像和伪图像之间的逐点插值。从[0,1]的均匀分布中得出图像比率(epsilon):
epsilon = tf.random.uniform((batch_size,1,1,1))
interpolates = epsilon*real_images + (1-epsilon)*fake_images
根据WGAN-GP论文,就我们的目的而言,我们可以这样理解,因为梯度来自真实图像和伪造图像的混合,因此我们不需要分别计算真实和伪造图像的损失。
∇ x ^ D ( x ^ ) \nabla _{\hat x}D(\hat x) ∇xD(x)项是评论家输出相对于插值的梯度。我们可以再次使用tf.GradientTape()
来获取梯度:
with tf.GradientTape() as gradient_tape:
gradient_tape.watch(interpolates)
critic_interpolates = self.critic(interpolates)
gradient_d = gradient_tape.gradient(critic_interpolates, [interpolates])
下一步是计算L2范数:
∥ ∇ x ^ D ( x ^ ) ∥ 2 \lVert \nabla _{\hat x}D(\hat x) \rVert_2 ∥∇xD(x)∥2
我们对每个值求平方,将它们加在一起,然后求平方根:
grad_loss = tf.square(grad)
grad_loss = tf.reduce_sum(grad_loss, axis=np.arange(1, len(grad_loss.shape)))
grad_loss = tf.sqrt(grad_loss)
在执行tf.reduce_sum()
时,我们排除了轴上的第一维,因为该维是batch大小。惩罚旨在使梯度范数接近1,这是计算梯度损失的最后一步:
grad_loss = tf.reduce_mean(tf.square(grad_loss - 1))
等式中的 λ λ λ是梯度惩罚与其他评论家损失的比率,在本这里中设置为10。现在,我们将所有评论家损失和梯度惩罚添加到反向传播并更新权重:
total_loss = loss_real + loss_fake + LAMBDA * grad_loss
gradients = total_tape.gradient(total_loss, self.critic.variables)
self.optimizer_critic.apply_gradients(zip(gradients, self.critic.variables))
这就是需要添加到WGAN中以使其成为WGAN-GP的所有内容。不过,需要删除以下部分:
-
权重裁剪
-
评论家中的批标准化
梯度惩罚是针对每个输入独立地对评论者的梯度范数进行惩罚。但是,批规范化会随着批处理统计信息更改梯度。为避免此问题,批规范化从评论家中删除。
评论家体系结构与WGAN相同,但不包括批规范化:
以下是经过训练的WGAN-GP生成的样本:
它们看起来清晰漂亮,非常类似于Fashion-MNIST数据集中的样本。训练非常稳定,很快就收敛了!
wgan_and_wgan_gp.py
import tensorflow as tf
from tensorflow.keras import layers, Model
from tensorflow.keras.activations import relu
from tensorflow.keras.models import Sequential, load_model
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import RMSprop, Adam
from tensorflow.keras.metrics import binary_accuracy
import tensorflow_datasets as tfds
import numpy as np
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings(‘ignore’)
print(“Tensorflow”, tf.version)
ds_train, ds_info = tfds.load(‘fashion_mnist’, split=‘train’,shuffle_files=True,with_info=True)
fig = tfds.show_examples(ds_train, ds_info)
batch_size = 64
image_shape = (32, 32, 1)
def preprocess(features):
image = tf.image.resize(features[‘image’], image_shape[:2])
image = tf.cast(image, tf.float32)
image = (image-127.5)/127.5
return image
ds_train = ds_train.map(preprocess)
ds_train = ds_train.shuffle(ds_info.splits[‘train’].num_examples)
ds_train = ds_train.batch(batch_size, drop_remainder=True).repeat()
train_num = ds_info.splits[‘train’].num_examples
train_steps_per_epoch = round(train_num/batch_size)
print(train_steps_per_epoch)
“”"
WGAN
“”"
class WGAN():
def init(self, input_shape):
self.z_dim = 128
self.input_shape = input_shape
losses
self.loss_critic_real = {}
self.loss_critic_fake = {}
self.loss_critic = {}
self.loss_generator = {}
critic
self.n_critic = 5
self.critic = self.build_critic()
self.critic.trainable = False
self.optimizer_critic = RMSprop(5e-5)
build generator pipeline with frozen critic
self.generator = self.build_generator()
critic_output = self.critic(self.generator.output)
self.model = Model(self.generator.input, critic_output)
self.model.compile(loss = self.wasserstein_loss,
optimizer = RMSprop(5e-5))
self.critic.trainable = True
def wasserstein_loss(self, y_true, y_pred):
w_loss = -tf.reduce_mean(y_true*y_pred)
return w_loss
def build_generator(self):
DIM = 128
model = tf.keras.Sequential(name=‘Generator’)
model.add(layers.Input(shape=[self.z_dim]))
model.add(layers.Dense(444*DIM))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(layers.Reshape((4,4,4*DIM)))
model.add(layers.UpSampling2D((2,2), interpolation=“bilinear”))
model.add(layers.Conv2D(2*DIM, 5, padding=‘same’))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(layers.UpSampling2D((2,2), interpolation=“bilinear”))
model.add(layers.Conv2D(DIM, 5, padding=‘same’))
model.add(layers.BatchNormalization())
model.add(layers.ReLU())
model.add(layers.UpSampling2D((2,2), interpolation=“bilinear”))
model.add(layers.Conv2D(image_shape[-1], 5, padding=‘same’, activation=‘tanh’))
return model
def build_critic(self):
DIM = 128
model = tf.keras.Sequential(name=‘critics’)
model.add(layers.Input(shape=self.input_shape))
model.add(layers.Conv2D(1*DIM, 5, strides=2, padding=‘same’))
model.add(layers.LeakyReLU(0.2))
model.add(layers.Conv2D(2*DIM, 5, strides=2, padding=‘same’))
model.add(layers.BatchNormalization())
model.add(layers.LeakyReLU(0.2))
model.add(layers.Conv2D(4*DIM, 5, strides=2, padding=‘same’))
自我介绍一下,小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。
深知大多数Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但对于培训机构动则几千的学费,着实压力不小。自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!
因此收集整理了一份《2024年Python开发全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。
既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上Python开发知识点,真正体系化!
由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新
如果你觉得这些内容对你有帮助,可以添加V获取:vip1024c (备注Python)
最后
Python崛起并且风靡,因为优点多、应用领域广、被大牛们认可。学习 Python 门槛很低,但它的晋级路线很多,通过它你能进入机器学习、数据挖掘、大数据,CS等更加高级的领域。Python可以做网络应用,可以做科学计算,数据分析,可以做网络爬虫,可以做机器学习、自然语言处理、可以写游戏、可以做桌面应用…Python可以做的很多,你需要学好基础,再选择明确的方向。这里给大家分享一份全套的 Python 学习资料,给那些想学习 Python 的小伙伴们一点帮助!
👉Python所有方向的学习路线👈
Python所有方向的技术点做的整理,形成各个领域的知识点汇总,它的用处就在于,你可以按照上面的知识点去找对应的学习资源,保证自己学得较为全面。
👉Python必备开发工具👈
工欲善其事必先利其器。学习Python常用的开发软件都在这里了,给大家节省了很多时间。
👉Python全套学习视频👈
我们在看视频学习的时候,不能光动眼动脑不动手,比较科学的学习方法是在理解之后运用它们,这时候练手项目就很适合了。
👉实战案例👈
学python就与学数学一样,是不能只看书不做题的,直接看步骤和答案会让人误以为自己全都掌握了,但是碰到生题的时候还是会一筹莫展。
因此在学习python的过程中一定要记得多动手写代码,教程只需要看一两遍即可。
👉大厂面试真题👈
我们学习Python必然是为了找到高薪的工作,下面这些面试题是来自阿里、腾讯、字节等一线互联网大厂最新的面试资料,并且有阿里大佬给出了权威的解答,刷完这一套面试资料相信大家都能找到满意的工作。
一个人可以走的很快,但一群人才能走的更远。如果你从事以下工作或对以下感兴趣,欢迎戳这里加入程序员的圈子,让我们一起学习成长!
AI人工智能、Android移动开发、AIGC大模型、C C#、Go语言、Java、Linux运维、云计算、MySQL、PMP、网络安全、Python爬虫、UE5、UI设计、Unity3D、Web前端开发、产品经理、车载开发、大数据、鸿蒙、计算机网络、嵌入式物联网、软件测试、数据结构与算法、音视频开发、Flutter、IOS开发、PHP开发、.NET、安卓逆向、云计算
👉大厂面试真题👈
我们学习Python必然是为了找到高薪的工作,下面这些面试题是来自阿里、腾讯、字节等一线互联网大厂最新的面试资料,并且有阿里大佬给出了权威的解答,刷完这一套面试资料相信大家都能找到满意的工作。
一个人可以走的很快,但一群人才能走的更远。如果你从事以下工作或对以下感兴趣,欢迎戳这里加入程序员的圈子,让我们一起学习成长!
AI人工智能、Android移动开发、AIGC大模型、C C#、Go语言、Java、Linux运维、云计算、MySQL、PMP、网络安全、Python爬虫、UE5、UI设计、Unity3D、Web前端开发、产品经理、车载开发、大数据、鸿蒙、计算机网络、嵌入式物联网、软件测试、数据结构与算法、音视频开发、Flutter、IOS开发、PHP开发、.NET、安卓逆向、云计算