GAN, the basic idea: one network takes an unlabeled input and produces an output; this is the generator. A second network, the discriminator, is trained on labeled inputs: real images are labeled 1 and generated images are labeled 0.
The generator's objective is to produce images that the discriminator labels 1.
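The standard way to write this objective (the minimax game from the original GAN paper, given here in LaTeX for reference) is

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]

In practice, as in the code below, the generator instead maximizes \log D(G(z)): the combined model is trained on noise with target label 1, which gives stronger gradients early in training.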
# Created: 2021/11/8 21:04
# Contents:
# Environment: tensorflow>2.0, numpy==1.19
# Notes: build a generative adversarial network (GAN)
import numpy as np
import tensorflow.keras as ks
import matplotlib.pyplot as plt
import tqdm

L = ks.layers
Latent_dim = 100           # dimension of the latent noise vector fed to the generator
Image_shape = (28, 28, 1)  # MNIST image shape
# Generator: maps a latent noise vector to a 28x28x1 image
generator_net = [
    L.Input(shape=(Latent_dim,)),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(1024),
    L.LeakyReLU(alpha=0.2),
    L.BatchNormalization(momentum=0.8),
    L.Dense(np.prod(Image_shape), activation='tanh'),  # tanh outputs in [-1, 1]
    L.Reshape(Image_shape),
]
generator = ks.models.Sequential(generator_net)
generator.summary()
# Discriminator: maps an image to a real/fake score in (0, 1)
discriminator_net = [
    L.Input(shape=Image_shape),
    L.Flatten(),
    L.Dense(512),
    L.LeakyReLU(alpha=0.2),
    L.Dense(256),
    L.LeakyReLU(alpha=0.2),
    L.Dense(1, activation='sigmoid'),
]
discriminator = ks.models.Sequential(discriminator_net)
discriminator.summary()
discriminator.compile(loss=ks.losses.binary_crossentropy,
                      optimizer=ks.optimizers.Adam(0.0002, 0.5),
                      metrics=['accuracy'])
# Combined model: noise -> generator -> discriminator.
# Stacking the two Sequential models directly avoids reusing the raw
# layer lists (which would put a second Input in the middle of the stack).
# Freezing the discriminator here only affects the GAN model compiled below;
# the discriminator was already compiled above while trainable, so its own
# train_on_batch calls still update its weights.
discriminator.trainable = False
GAN = ks.models.Sequential([generator, discriminator])
GAN.compile(loss=ks.losses.binary_crossentropy,
            optimizer=ks.optimizers.Adam(0.0002, 0.5),
            metrics=['accuracy'])
GAN.summary()
def train(batches=30000, batch_size=32):
    # Only the images are needed, not the labels
    (image_set, _), (_, _) = ks.datasets.mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output
    image_set = image_set / 127.5 - 1
    image_set = image_set.reshape(len(image_set), 28, 28, 1)
    # Target labels: 1 for real images, 0 for generated ones
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    batch_list = tqdm.trange(batches)
    for step in batch_list:
        # Train the discriminator on a real batch and a generated batch
        idx = np.random.randint(0, image_set.shape[0], batch_size)
        imgs = image_set[idx]
        noise = np.random.normal(0, 1, (batch_size, Latent_dim))
        gen_imgs = generator.predict(noise)
        d_state_real = discriminator.train_on_batch(imgs, valid)
        d_state_fake = discriminator.train_on_batch(gen_imgs, fake)
        d_state = 0.5 * np.add(d_state_real, d_state_fake)
        # Train the generator through the combined model (discriminator frozen)
        noise = np.random.normal(0, 1, (batch_size, Latent_dim))
        adv_state = GAN.train_on_batch(noise, valid)
        state = f"[D loss:{d_state[0]:.4f} acc:{d_state[1]:.4f}]" \
                f"[A loss:{adv_state[0]:.4f} acc:{adv_state[1]:.4f}]"
        batch_list.set_postfix(state=state)

train(batches=3000, batch_size=32)
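# The script imports matplotlib but never uses it. A minimal sketch for
# inspecting the trained generator; the 5x5 grid and the output filename
# are illustrative choices, not part of the original script.
noise = np.random.normal(0, 1, (25, Latent_dim))
gen_imgs = generator.predict(noise)
gen_imgs = 0.5 * gen_imgs + 0.5  # rescale from [-1, 1] back to [0, 1]
fig, axs = plt.subplots(5, 5)
for i in range(5):
    for j in range(5):
        axs[i, j].imshow(gen_imgs[i * 5 + j, :, :, 0], cmap='gray')
        axs[i, j].axis('off')
fig.savefig('gan_mnist_samples.png')
plt.close(fig)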
Within one round of backpropagation, GAN training splits into two steps: first train D, then train G. Note that this does not mean training D to completion before starting on G; training D needs the output of G from the previous round as its input.
When training D, the images G produced in the previous round are concatenated with real images to form x, and y is laid out in the same order: fake images get label 0, real images get label 1. Feeding x through D produces a score (a number between 0 and 1), and the loss built from the score and y can then be backpropagated, as sketched below. (The example in the figure uses batch = 1, so len(y) = 2*batch; in practice a larger batch is usually used.)
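A minimal sketch of that single-batch D update, reusing the variables from the training loop above (which instead performs two separate train_on_batch calls):

# One D update on a single concatenated batch:
# fake images labeled 0, real images labeled 1, as described above.
noise = np.random.normal(0, 1, (batch_size, Latent_dim))
gen_imgs = generator.predict(noise)           # fakes from the previous G
x = np.concatenate([gen_imgs, imgs], axis=0)  # fake images first, then real ones
y = np.concatenate([np.zeros((batch_size, 1)),   # 0 for fake
                    np.ones((batch_size, 1))])   # 1 for real
d_state = discriminator.train_on_batch(x, y)     # sigmoid score in (0, 1) vs. y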