写GAN网络的测试代码参考博客
遇到的问题:
我在代码: generator.load_weights('D:/imageRegistration/languageGan/MyGAN/models/G_model/450.h5',by_name = True) 中加入了by_name=True 这一句,测试的过程中,测试生成器时,一直生成的还是噪声图,但是在输出的图片文件夹中,输出的图片中数字的效果还是很好的,我把这一句去掉之后,测试生成器,生成的图片效果就与图片中的效果差不多。
keras源码engine中topology.py定义了加载权重的函数:
load_weights(self, filepath, by_name=False)
其中默认by_name为False,这时候加载权重按照网络拓扑结构加载,适合直接使用keras中自带的网络模型,如VGG16、VGG19、ResNet50等
若将by_name改为True则加载权重按照layer的name进行,layer的name相同时加载权重,适合在改变了模型的相关结构或增加了节点、但利用了原网络主体结构的情况下使用。
引用博客
加载模型
load_weights()
load_model()
参考博客
# -*- coding: utf-8 -*-
"""
Created on Sat May 22 15:40:12 2021
@author: LiMeng
"""
from keras.datasets import mnist
from keras.layers import Dense, Dropout, Input
from keras.models import Model,Sequential
from keras.layers.advanced_activations import LeakyReLU
from keras.optimizers import Adam
from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
#%matplotlib inline
#from google.colab import drive
print_dir = "D:/imageRegistration/languageGan/MyGAN/printImage"
model_dir_G = "D:/imageRegistration/languageGan/MyGAN/models/G_model"
model_dir_D = "D:/imageRegistration/languageGan/MyGAN/models/D_model"
# Load the dataset
def load_data():
    """Load the MNIST training split, scaled and flattened for the GAN.

    Returns:
        (x_train, y_train): x_train is float32 with shape (num_samples, 784),
        scaled from [0, 255] to [-1, 1] to match the generator's tanh output
        range; y_train holds the integer digit labels.
    """
    (x_train, y_train), (_, _) = mnist.load_data()
    # Scale pixel values from [0, 255] into [-1, 1].
    x_train = (x_train.astype(np.float32) - 127.5)/127.5
    # Flatten (num_samples, 28, 28) -> (num_samples, 784). Using -1 instead
    # of a hard-coded 60000 keeps this correct for any sample count.
    x_train = x_train.reshape(-1, 784)
    return (x_train, y_train)
# Load MNIST once at module import and report the array shapes.
X_train, y_train = load_data()
print(f"{X_train.shape} {y_train.shape}")
def build_generator():
    """Construct the MLP generator.

    Maps a 100-dimensional noise vector to a flattened 28x28 image
    (784 values), squashed into [-1, 1] by the final tanh activation.
    """
    net = Sequential()
    # First dense layer fixes the 100-dim noise input.
    net.add(Dense(units=256, input_dim=100))
    net.add(LeakyReLU(alpha=0.2))
    # Widening hidden stack: 512 then 1024 units, each with LeakyReLU.
    for width in (512, 1024):
        net.add(Dense(units=width))
        net.add(LeakyReLU(alpha=0.2))
    net.add(Dense(units=784, activation='tanh'))
    net.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return net
# Build the (module-global) generator used by train_GAN/test below and
# print its layer-by-layer summary.
generator = build_generator()
generator.summary()
"""
dropout是指在深度学习网络的训练过程中,对于神经网络单元,按照一定的概率将其暂时从网络中丢弃。注意是暂时,对于随机梯度下降来说,由于是随机丢弃,故而每一个mini-batch都在训练不同的网络。
dropout是CNN中防止过拟合提高效果的一个大杀器
"""
def build_discriminator():
    """Construct the MLP discriminator.

    Maps a flattened 28x28 image (784 values) to a single sigmoid
    probability that the image is real. Dropout after each hidden layer
    regularizes the network.
    """
    net = Sequential()
    # First dense layer fixes the 784-dim flattened-image input.
    net.add(Dense(units=1024 ,input_dim=784))
    net.add(LeakyReLU(alpha=0.2))
    net.add(Dropout(0.3))
    # Narrowing hidden stack: 512 then 256 units, each with dropout.
    for width in (512, 256):
        net.add(Dense(units=width))
        net.add(LeakyReLU(alpha=0.2))
        net.add(Dropout(0.3))
    net.add(Dense(units=1, activation='sigmoid'))
    net.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return net
#discriminator = build_discriminator()
#discriminator.summary()
def build_GAN(discriminator, generator):
    """Chain generator and discriminator into the combined GAN model.

    The discriminator is frozen before compiling the combined model, so
    training the combined model updates only the generator's weights.
    """
    discriminator.trainable=False
    z = Input(shape=(100,))
    # noise -> generated image -> discriminator's real/fake score
    validity = discriminator(generator(z))
    combined = Model(inputs=z, outputs=validity)
    combined.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
    return combined
#GAN = build_GAN(discriminator, generator)
#GAN.summary()
def draw_images(generator, epoch, examples=25, dim=(5,5), figsize=(10,10)):
    """Sample images from the generator and save them as a subplot grid.

    Args:
        generator: generator model mapping 100-dim noise to 784 values.
        epoch: current epoch number, interpolated into the output filename.
        examples: number of images to sample (should equal dim[0]*dim[1]).
        dim: (rows, cols) of the subplot grid.
        figsize: matplotlib figure size in inches.
    """
    noise= np.random.normal(loc=0, scale=1, size=[examples, 100])
    generated_images = generator.predict(noise)
    # Reshape by `examples` instead of the previous hard-coded 25 so the
    # function works for any requested sample count.
    generated_images = generated_images.reshape(examples, 28, 28)
    plt.figure(figsize=figsize)
    for i in range(generated_images.shape[0]):
        plt.subplot(dim[0], dim[1], i+1)
        plt.imshow(generated_images[i], interpolation='nearest', cmap='Greys')
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('D:/imageRegistration/languageGan/MyGAN/printImage/Generated_images%d.png' %epoch)
    # Close the figure so repeated calls during training do not accumulate
    # open figures (matplotlib warns and leaks memory otherwise).
    plt.close()
def train_GAN(epochs=1, batch_size=128,sample_interval = 50):
    """Train the GAN on MNIST with alternating discriminator/generator steps.

    Args:
        epochs: number of training epochs.
        batch_size: samples per training batch; NOTE it is also used as the
            number of batches per epoch (see inner loop below).
        sample_interval: save generator/discriminator weights every this
            many epochs.
    """
    #Loading the data
    X_train, y_train = load_data()
    # Creating GAN
    generator = build_generator()
    discriminator = build_discriminator()
    GAN = build_GAN(discriminator, generator)
    # print("\nepochs %d\n" %epochs)
    for i in range(1, epochs+1):
        print("\nEpoch %d\n" %i)
        # NOTE(review): this runs `batch_size` batches per epoch rather than
        # len(X_train)//batch_size — presumably intentional in the original
        # tutorial, but worth confirming.
        for _ in tqdm(range(batch_size)):
            # Generate fake images from random noiset
            """
            np.random.normal生成正态分布:均值为0,方差为1,维度为(batch_size,100)--->(128,100)
            generator生成器的输入是128张100维的数据 fake_images即为generator生成的虚假图像
            np.random.randint(low,high,size,dtype) array = np.random.randint(0, X_train.shape[0], batch_size)
            生成128维(行向量) 0-60000的数字,X_train.shape[0]为 60000,
            """
            # Sample a batch of 100-dim noise vectors and run them through
            # the generator to get a batch of fake images.
            noise= np.random.normal(0,1, (batch_size, 100))
            fake_images = generator.predict(noise)
            """
            #画出虚假图像
            plt.imshow(fake_images)
            plt.title("fake_images")
            """
            # Select a random batch of real images from MNIST
            real_images = X_train[np.random.randint(0, X_train.shape[0], batch_size)]
            """
            #画出真实图像
            plt.imshow(real_images)
            plt.title("real_images")
            """
            # Labels for fake and real images
            label_fake = np.zeros(batch_size)
            label_real = np.ones(batch_size)
            # Concatenate fake and real images
            """
            X = np.concatenate([fake_images, real_images])
            y = np.concatenate([label_fake, label_real])
            """
            # Train the discriminator
            discriminator.trainable=True
            """
            discriminator.train_on_batch(X, y)
            """
            # Train on the real and fake batches separately, then average
            # the two losses for reporting.
            d_loss_real = discriminator.train_on_batch(real_images, label_real)
            d_loss_fake = discriminator.train_on_batch(fake_images,label_fake)
            d_loss = 0.5 * np.add(d_loss_real,d_loss_fake)
            print('\n')
            print(' ---d_loss_real, d_loss_fake, d_loss: ',d_loss_real,d_loss_fake,d_loss)
            # Train the generator/chained GAN model (with frozen weights in discriminator)
            discriminator.trainable=False
            """
            GAN.train_on_batch(noise, label_real)
            """
            # The generator is trained to make the (frozen) discriminator
            # label its output as real, hence label_real as the target.
            g_loss = GAN.train_on_batch(noise, label_real)
            print(' ---g_loss\n',g_loss)
            """
            print("%d [D loss :%f, acc: %.2f] [G loss: %f]" , i, d_loss[0], 100*d_loss[1], g_loss)
            """
        # Periodically checkpoint both models to disk.
        if i % sample_interval == 0:
            generator.save('D:/imageRegistration/languageGan/MyGAN/models/G_model/%d.h5'%i)
            discriminator.save('D:/imageRegistration/languageGan/MyGAN/models/D_model/%d.h5'%i)
        # Draw generated images at epoch 1 and every 10 epochs thereafter.
        if i == 1 or i % 10 == 0:
            draw_images(generator, i)
def test(gen_number = 10, save=False):
    """Smoke-test the module-global generator by sampling one image.

    Args:
        gen_number: currently unused; kept for backward compatibility.
        save: currently unused; kept for backward compatibility.

    NOTE(review): per the notes at the top of this file, when loading
    checkpointed weights here, call load_weights WITHOUT by_name=True —
    with by_name=True the freshly built generator kept producing noise.
    """
    #test# generator.load_weights('D:/imageRegistration/languageGan/MyGAN/models/G_model/450.h5')
    #discriminator.load_weights("D:/image registration/language gan/My GAN/models/G_model/2.h5",by_name=True)
    noise_test = np.random.normal(0,1,(1,100))
    fake_images = generator.predict(noise_test)
    # fake_images = 0.5*fake_images + 0.5
    fake_images = fake_images.reshape(28,28)
    plt.imshow(fake_images)
    # Without an explicit show(), nothing is rendered when this file runs
    # as a plain script on a non-interactive matplotlib backend.
    plt.show()
#train_GAN(epochs=10000, batch_size=128)
# Run the generator smoke test only when executed as a script, not when
# this module is imported by other code.
if __name__ == "__main__":
    test()