【深度学习】生成对抗网络

GAN:原始的理解,没有标签的输入一个网络得到的一个输出,作为生成器
另外一个网络,标记输入,标签以真实图像对应标签1,生成的图像作为0.
目标函数时是生成判别器认为是1的标签图像。

#开发时间:2021/11/8 21:04
#开发内容:
#运行环境:tensroflow>2.0 numpy == 1.19
#备注内容:搭建生成对抗网络
import numpy as np
import tensorflow.keras as ks
import matplotlib.pyplot as plt

import tqdm
L=ks.layers
Latent_dim=100
Image_shape=(28,28,1)
#生成器
generator_net=[L.Input(shape=(Latent_dim,)),
			   L.Dense(256),
			   L.LeakyReLU(alpha=0.2),
			   L.BatchNormalization(momentum=0.8),
			   L.Dense(512),
			   L.LeakyReLU(alpha=0.2),
			   L.BatchNormalization(momentum=0.8),
			   L.Dense(1024),
			   L.LeakyReLU(alpha=0.2),
			   L.BatchNormalization(momentum=0.8),
			   L.Dense(np.prod(Image_shape),activation='tanh'),
			   L.Reshape(Image_shape)]
generator=ks.models.Sequential(generator_net)
generator.summary()

#判别器

discriminator_net = [L.Input(shape=Image_shape),
					 L.Flatten(),
					 L.Dense(512),
					 L.LeakyReLU(alpha=0.2),
					 L.Dense(256),
					 L.LeakyReLU(alpha=0.2),
					 L.Dense(1,activation='sigmoid')]
discriminator=ks.models.Sequential(discriminator_net)
discriminator.summary()


discriminator.compile(loss=ks.losses.binary_crossentropy,
					  optimizer=ks.optimizers.Adam(0.0002,0.5),
					  metrics=['accuracy'])

GAN_net=generator_net+discriminator_net
for layer in discriminator_net:
	layer.trainable=False
GAN=ks.models.Sequential(GAN_net)

GAN.compile(loss=ks.losses.binary_crossentropy,
			optimizer=ks.optimizers.Adam(0.0002,0.5),
			metrics=['accuracy'])
GAN.summary()

def train(batch=30000,batch_size=32):
	(image_set,_),(_,_)=ks.datasets.mnist.load_data()
	#只需要图片不需要标签
	image_set = image_set/127.5-1
	image_set= image_set.reshape(len(image_set),28,28,1)
	#准备batch_size
	valid=np.ones((batch_size))
	fake=np.zeros((batch_size))

	batch_list=tqdm.trange(batch)
	for batch in batch_list:
		idx=np.random.randint(0,image_set.shape[0],batch_size)
		imgs=image_set[idx]

		noise= np.random.normal(0,1,(batch_size,Latent_dim))

		gen_imgs=generator.predict(noise)

		d_state_real=discriminator.train_on_batch(imgs,valid)
		d_state_fake=discriminator.train_on_batch(gen_imgs,fake)
		d_state=0.5*np.add(d_state_real,d_state_fake)

		noise=np.random.normal(0,1,(batch_size,Latent_dim))
		adv_state=GAN.train_on_batch(noise,valid)

		state=f"[D loss:{d_state[0]:.4f} acc:{d_state[1]:.4f}]"\
		      f"[A loss:{adv_state[0]:.4f} acc:{adv_state[1]:.4f}]"
		batch_list.set_postfix(state=state)

train(batch=3000,batch_size=32)

在这里插入图片描述
在这里插入图片描述
GAN的训练在同一轮梯度反传的过程中可以细分为2步,先训练D在训练G;注意不是等所有的D训练好以后,才开始训练G,因为D的训练也需要上一轮梯度反传中G的输出值作为输入。

当训练D的时候,上一轮G产生的图片,和真实图片,直接拼接在一起,作为x。然后根据,按顺序摆放0和1,假图对应0,真图对应1。然后就可以通过,x输入生成一个score(从0到1之间的数),通过score和y组成的损失函数,就可以进行梯度反传了。(我在图片上举的例子是batch = 1,len(y)=2*batch,训练时通常可以取较大的batch)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

高山仰止景

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值