This example is the main program built on top of the GAN network constructed in the previous section.
from GANFolder.GAN import Generator, Discriminator
import tensorflow as tf
import numpy as np
import os
from PIL import Image

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'   # suppress TensorFlow INFO/WARNING log spam
os.environ['CUDA_VISIBLE_DEVICES'] = '0'   # run on the first GPU only
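# Generator and Discriminator are defined in the previous section. From the
# build() calls below, the assumed contract is: the generator maps a
# [None, z_dim] latent batch to [None, 64, 64, 3] images, and the
# discriminator maps such images to one raw logit per sample.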
def ParseData(batchData):
    # The feature spec must match the one used when the TFRecord was written.
    feature = {
        'width': tf.io.FixedLenFeature([], tf.int64),
        'height': tf.io.FixedLenFeature([], tf.int64),
        'wmode': tf.io.FixedLenFeature([], tf.int64),
        'raw_image': tf.io.FixedLenFeature([], tf.string),
    }
    example = tf.io.parse_single_example(batchData, feature)
    raw_image_uint8 = tf.image.decode_jpeg(example['raw_image'], channels=3)
    # resize() returns float32; pixel values stay in [0, 255]. Depending on the
    # Generator's output range, rescaling here (e.g. to [-1, 1]) may be needed.
    image_tensor = tf.image.resize(raw_image_uint8, (64, 64))
    return image_tensor
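# For reference, a minimal sketch (an assumption, not part of the original
# pipeline) of how a TFRecord matching the feature spec above could have been
# written. This helper is illustrative only and is never called by the script.
def WriteTFRecordSketch(jpeg_paths, out_path):
    with tf.io.TFRecordWriter(out_path) as writer:
        for path in jpeg_paths:
            with open(path, 'rb') as f:
                raw = f.read()                  # raw JPEG bytes; decode_jpeg handles them later
            width, height = Image.open(path).size
            example = tf.train.Example(features=tf.train.Features(feature={
                'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[width])),
                'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[height])),
                # 'wmode' is opaque in the original; 0 is a placeholder.
                'wmode': tf.train.Feature(int64_list=tf.train.Int64List(value=[0])),
                'raw_image': tf.train.Feature(bytes_list=tf.train.BytesList(value=[raw])),
            }))
            writer.write(example.SerializeToString())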
batchSize = 128
EPOCH = 300000          # total training steps (one batch per step)
lr = 0.0002
isTraining = True
os.makedirs("./save_imgs", exist_ok=True)   # PIL will not create missing directories

dataSets = tf.data.TFRecordDataset(r"E:\PycharmProjects\untitled\GANFolder\k.tfrecord")
dataSets = dataSets.map(ParseData).shuffle(10000).batch(batchSize, drop_remainder=True).repeat()
dataSets_iter = iter(dataSets)
z_dim = 100             # dimensionality of the latent noise vector

generator = Generator()
generator.build(input_shape=[None, z_dim])
generator.summary()
print()
discriminator = Discriminator()
discriminator.build(input_shape=[None, 64, 64, 3])
discriminator.summary()

# Adam with beta_1 = 0.5, the setting recommended in the DCGAN paper.
G_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
D_optimizer = tf.keras.optimizers.Adam(learning_rate=lr, beta_1=0.5)
# Best losses seen so far; weights are checkpointed whenever a new minimum appears.
G_LOSS = float('inf')
D_LOSS = float('inf')
def real_loss_ones(logits):
    # Binary cross-entropy against all-ones labels ("real").
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.ones_like(logits)))

def fake_loss_zeros(logits):
    # Binary cross-entropy against all-zeros labels ("fake").
    return tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=tf.zeros_like(logits)))
def d_loss_func(gen, dis, real, noise, isTraining):
    # Discriminator loss: real images should score 1, generated images 0.
    fake_image = gen(noise, training=isTraining)
    fake_logits = dis(fake_image, training=isTraining)
    real_logits = dis(real, training=isTraining)
    r_loss = real_loss_ones(real_logits)
    f_loss = fake_loss_zeros(fake_logits)
    return r_loss + f_loss

def g_loss_func(gen, dis, noise, isTraining):
    # Generator loss: generated images should be scored as real.
    fake_image = gen(noise, training=isTraining)
    logits = dis(fake_image, training=isTraining)
    return real_loss_ones(logits)
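# In expectation, these two functions implement the standard GAN objectives:
#   d_loss = -E[log D(x)] - E[log(1 - D(G(z)))]
#   g_loss = -E[log D(G(z))]     (the non-saturating generator loss)
# Quick sanity check (illustrative): a strongly positive logit should give a
# near-zero "real" loss, e.g. real_loss_ones(tf.constant([10.])) ≈ 0.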
for epoch in range(EPOCH):
    real_data = next(dataSets_iter)
    # Sample latent vectors uniformly from [-1, 1).
    fake_data = tf.random.uniform([batchSize, z_dim], minval=-1., maxval=1.)
    if epoch % 5 == 0:  # train the discriminator once for every five generator updates
        with tf.GradientTape() as d_tape:
            d_loss = d_loss_func(generator, discriminator, real_data, fake_data, isTraining)
        grad = d_tape.gradient(d_loss, discriminator.trainable_variables)
        D_optimizer.apply_gradients(zip(grad, discriminator.trainable_variables))
        if d_loss < D_LOSS:
            D_LOSS = d_loss
            discriminator.save_weights("./discriminatorWeight/d.ckpt")
            print("Discriminator weights have been saved!")
    with tf.GradientTape() as g_tape:
        g_loss = g_loss_func(generator, discriminator, fake_data, isTraining)
    grad = g_tape.gradient(g_loss, generator.trainable_variables)
    G_optimizer.apply_gradients(zip(grad, generator.trainable_variables))
    if g_loss < G_LOSS:
        G_LOSS = g_loss
        generator.save_weights("./GeneratorWeight/g.ckpt")
        print("Generator weights have been saved!")
    if epoch % 100 == 0:
        # Periodically sample images in inference mode and report both losses.
        fake_image_data = tf.random.uniform([100, z_dim], minval=-1., maxval=1.)
        fake_Image = generator(fake_image_data, training=False)
        print("g_loss:", float(g_loss), "d_loss:", float(d_loss))
        for index, single_image in enumerate(fake_Image):
            # Squash raw generator outputs into [0, 255] via sigmoid.
            single_image = np.uint8(tf.nn.sigmoid(single_image) * 255.)
            if index == 99:  # only the last of the 100 samples is written to disk
                img = Image.fromarray(single_image, 'RGB')
                img.save("./save_imgs/{a}-{b}.png".format(a=epoch, b=index))
        print("Epoch {a}'s images have been saved!".format(a=epoch + 1))