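# 原始 GAN 在真实分布与生成分布(几乎)没有交集时训练效果很差;WGAN 改用 Wasserstein 距离来解决这一问题,
# WGAN-GP 进一步在判别器损失中加入了梯度惩罚。本脚本保留了标准的交叉熵 GAN 损失,只借用了其中的梯度惩罚项。
# (English: the original GAN trains poorly when the real and generated distributions
#  barely overlap; WGAN addresses this with the Wasserstein distance, and WGAN-GP adds
#  a gradient penalty to the critic loss. This script keeps the standard cross-entropy
#  GAN loss and borrows only the gradient-penalty term.)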
from GANFolder.GAN import Generator,Discriminator
import tensorflow as tf
import numpy as np
import os
from PIL import Image
# Process-wide environment configuration; must run before TensorFlow touches
# the GPU or starts logging, so it sits right after the imports.
os.environ.update({
    'TF_CPP_MIN_LOG_LEVEL': '2',  # presumably to quiet TensorFlow's C++ logging
    'CUDA_VISIBLE_DEVICES': '0',  # restrict the process to GPU index 0
})
def ParseData(batchData):
feature = {
'width':tf.io.FixedLenFeature([],tf.int64),
'height':tf.io.FixedLenFeature([],tf.int64),
'wmode':tf.io.FixedLenFeature([],tf.int64),
'raw_image':tf.io.FixedLenFeature([],tf.string),
}
example = tf.io.parse_single_example(batchData,feature)
raw_image_uint8 = tf.image.decode_jpeg(example['raw_image'],channels=3)
image_tensor = tf.image.resize(raw_image_uint8,(64,64))
return image_tensor
# ---------------------------------------------------------------- hyper-parameters
batchSize = 128
EPOCH = 300000  # number of training iterations (one batch each), not dataset passes
lr = 0.0002
isTraining = True

# ---------------------------------------------------------------- input pipeline
# NOTE(review): absolute, machine-specific path; consider making it configurable.
dataSets = tf.data.TFRecordDataset(r"E:\PycharmProjects\untitled\GANFolder\Duke.tfrecord")
# drop_remainder=True keeps every batch at exactly batchSize; repeat() makes the
# stream endless so next(dataSets_iter) in the training loop never runs dry.
dataSets = dataSets.map(ParseData).shuffle(10000).batch(batchSize, drop_remainder=True).repeat()
dataSets_iter = iter(dataSets)

# ---------------------------------------------------------------- models
z_dim = 100  # length of the latent noise vector fed to the generator
generater = Generator()
generater.build(input_shape=[None, z_dim])
generater.summary()
print()
discriminater = Discriminator()
# 64x64x3 matches the size ParseData resizes real images to.
discriminater.build(input_shape=[None, 64, 64, 3])
discriminater.summary()

G_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
D_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)

# Best-so-far losses used to decide when to checkpoint. Starting at +inf
# guarantees the first computed loss triggers a save; the previous arbitrary
# sentinel (1e6) would silently skip checkpoints if losses started above it.
G_LOSS = float('inf')
D_LOSS = float('inf')
def real_loss_ones(losigts):
return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=losigts,labels=tf.ones_like(losigts)))
def fake_loss_zeros(losigts):
    """Mean sigmoid cross-entropy of ``losigts`` against all-zeros targets.

    Used for logits that should be classified as "fake".
    """
    targets = tf.zeros_like(losigts)
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=targets, logits=losigts)
    return tf.reduce_mean(per_element)
def gradient_penalty(dis, real, fakea):
    """Compute a WGAN-GP style gradient penalty for the discriminator ``dis``.

    Random points are sampled on the straight line between each real image and
    its paired fake image. ``dis`` is evaluated at those points, and the penalty
    is the batch mean of ``(||d dis / d x||_2 - 1) ** 2``.

    Args:
        dis: Discriminator model to penalise (called as ``dis(x)``). Previously
            this argument was ignored in favour of the module-level
            ``discriminater``.
        real: Batch of real images; assumed rank 4 (batch, H, W, C), since the
            interpolation coefficient has shape (batch, 1, 1, 1).
        fakea: Batch of generated images with the same shape as ``real``.

    Returns:
        Scalar tensor holding the mean squared deviation of the per-sample
        gradient norm from 1.
    """
    # Dynamic batch size works in both eager and graph (tf.function) mode.
    batch = tf.shape(real)[0]
    # One interpolation coefficient per sample, broadcast over H, W and C.
    t = tf.random.uniform([batch, 1, 1, 1])
    inter = t * real + (1. - t) * fakea
    with tf.GradientTape() as tape:
        # inter is a plain tensor, not a Variable, so it must be watched explicitly.
        tape.watch(inter)
        logits = dis(inter)
    grad = tape.gradient(logits, inter)
    grad = tf.reshape(grad, [batch, -1])
    # A tiny epsilon keeps sqrt differentiable when a gradient is exactly zero.
    # tf.norm would otherwise yield NaN gradients through the outer tape.
    l2_norm = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=1) + 1e-12)
    return tf.reduce_mean((l2_norm - 1.) ** 2)
def d_loss_func(gen, dis, real, fakea, isTraining, gp_weight=1.0):
    """Discriminator loss: BCE on real/fake logits plus a weighted gradient penalty.

    NOTE(review): this is the standard sigmoid cross-entropy GAN loss with a
    WGAN-GP style penalty added. It is not the Wasserstein critic loss.

    Args:
        gen: Generator model.
        dis: Discriminator model being trained. The gradient penalty is now
            computed on this model rather than on the module-level
            ``discriminater``.
        real: Batch of real images.
        fakea: Batch of latent noise vectors fed to ``gen`` (not images).
        isTraining: Forwarded as ``training=`` to both models.
        gp_weight: Coefficient on the gradient penalty. The default of 1.0
            reproduces the previous hard-coded value; the WGAN-GP paper uses 10.

    Returns:
        Scalar loss tensor ``r_loss + f_loss + gp_weight * GP``.
    """
    fake_image = gen(fakea, training=isTraining)
    fake_logits = dis(fake_image, training=isTraining)
    real_logits = dis(real, training=isTraining)
    r_loss = real_loss_ones(real_logits)    # real images -> label 1
    f_loss = fake_loss_zeros(fake_logits)   # generated images -> label 0
    # Gradient penalty scaled by a tunable coefficient controlling its strength.
    GP = gradient_penalty(dis, real, fake_image)
    return r_loss + f_loss + gp_weight * GP
def g_loss_func(gen, dis, fakea, isTraining):
    """Generator loss: push the discriminator to label generated images as real.

    Args:
        gen: Generator model.
        dis: Discriminator model.
        fakea: Batch of latent noise vectors fed to ``gen``.
        isTraining: Forwarded as ``training=`` to both models.

    Returns:
        Scalar BCE loss of ``dis(gen(fakea))`` against all-ones targets.
    """
    generated = gen(fakea, training=isTraining)
    return real_loss_ones(dis(generated, training=isTraining))
# Training loop. NOTE: despite its name, EPOCH is the number of optimisation
# iterations (one batch per iteration), not the number of dataset passes.
# Create every output directory up front, because PIL's Image.save does not
# create missing parent directories.
for _out_dir in ("./discirminaterWeight", "./GeneratorWeight", "./save_imgs"):
    os.makedirs(_out_dir, exist_ok=True)

for epoch in range(EPOCH):
    real_data = next(dataSets_iter)
    # Latent noise z ~ U(-1, 1) with shape (batchSize, z_dim).
    fake_data = tf.random.uniform([batchSize, z_dim], minval=-1., maxval=1.)

    # Discriminator step: runs once every 4 iterations. The generator below is
    # updated on every iteration. (The previous comment described a 5:1
    # schedule, which this code never implemented.)
    if epoch % 4 == 0:
        with tf.GradientTape() as d_tape:
            d_loss = d_loss_func(generater, discriminater, real_data, fake_data, isTraining)
        grad = d_tape.gradient(d_loss, discriminater.trainable_variables)
        D_optimizer.apply_gradients(zip(grad, discriminater.trainable_variables))
        # Checkpoint whenever the discriminator loss reaches a new minimum.
        # NOTE(review): in adversarial training a lower loss does not imply a
        # better model; a fixed checkpoint schedule may be more useful.
        if d_loss < D_LOSS:
            D_LOSS = d_loss
            discriminater.save_weights("./discirminaterWeight/duwei.ckpt")
            print("Discriminator Weights has been Refreshed!")

    # Generator step: every iteration.
    with tf.GradientTape() as g_tape:
        g_loss = g_loss_func(generater, discriminater, fake_data, isTraining)
    grad = g_tape.gradient(g_loss, generater.trainable_variables)
    G_optimizer.apply_gradients(zip(grad, generater.trainable_variables))
    if g_loss < G_LOSS:
        G_LOSS = g_loss
        generater.save_weights("./GeneratorWeight/duwei.ckpt")
        print("Generator Weights has been Refreshed!")

    # Every 100 iterations, log both losses and save one sample image.
    # 100 is a multiple of 4, so a discriminator step has already run in this
    # iteration and d_loss is defined. Keep that invariant if either interval
    # changes.
    if epoch % 100 == 0:
        fake_image_data = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
        fake_Image = generater(fake_image_data, training=False)
        print("g_loss:", float(g_loss), "d_loss:", float(d_loss))
        # Only the last sample (index 99) is written. Converting the other 99
        # samples, as before, was wasted work.
        last = len(fake_Image) - 1
        single_image = np.uint8(tf.nn.sigmoid(fake_Image[last]) * 255.)
        new_img = Image.new('RGB', (64, 64))
        new_img.paste(Image.fromarray(single_image, 'RGB'), (0, 0))
        new_img.save("./save_imgs/{a}-{b}.png".format(a=epoch, b=last))
        # Report the same iteration number that appears in the file name.
        print("Epoch {a} 's Images has been saved!".format(a=epoch))