"""条件GAN的实现:
基于MNIST数据集"""
import os
os.environ["TF_CPP_MIN_LOG_LEVEL"]="2"
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Activation,BatchNormalization,Concatenate,Dense,Embedding,Flatten,Input,Multiply,Reshape
from tensorflow.keras.layers import Conv2D,Conv2DTranspose
from tensorflow.keras.models import Model,Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
# Image dimensions of MNIST samples (28x28 grayscale).
img_rows=28
img_cols=28
channels=1
img_shape=(img_rows,img_cols,channels)
# Dimensionality of the generator's input noise vector.
z_dim=100
# Number of conditioning classes (digits 0-9).
num_classes=10
# Core generator network.
def build_generator(z_dim):
    """Return the core generator.

    Maps a z_dim-dimensional latent vector to a 28x28x1 image with
    values in [-1, 1] (tanh output) via transposed convolutions.
    """
    return Sequential([
        # Project the latent vector and reshape it into a 7x7x256 feature map.
        Dense(256 * 7 * 7, input_dim=z_dim),
        Reshape((7, 7, 256)),
        # Upsample 7x7 -> 14x14.
        Conv2DTranspose(128, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        # Refine at 14x14 (stride 1: no resolution change).
        Conv2DTranspose(64, kernel_size=3, strides=1, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        # Upsample 14x14 -> 28x28 and collapse to a single channel.
        Conv2DTranspose(1, kernel_size=3, strides=2, padding="same"),
        Activation("tanh"),
    ])
# Conditional generator: noise vector + class label -> image.
def build_cgan_generator(z_dim):
    """Return a Model mapping (noise, integer label) -> generated image."""
    noise_in = Input(shape=(z_dim,))
    label_in = Input(shape=(1,), dtype="int32")
    # Embed the integer label into a dense vector of the same size as z.
    embedded_label = Embedding(num_classes, z_dim, input_length=1)(label_in)
    embedded_label = Flatten()(embedded_label)
    # Inject the condition by element-wise multiplication with the noise.
    conditioned_z = Multiply()([noise_in, embedded_label])
    conditioned_img = build_generator(z_dim)(conditioned_z)
    return Model([noise_in, label_in], conditioned_img)
# Core discriminator network.
def build_discriminator(img_shape):
    """Return the core discriminator.

    Expects the image stacked with one extra label channel, i.e. input
    shape (H, W, C + 1); outputs a real/fake probability via sigmoid.
    """
    return Sequential([
        # Note the +1 channel: the conditional wrapper concatenates a
        # label map onto the image before feeding it in.
        Conv2D(64, kernel_size=3, strides=2,
               input_shape=(img_shape[0], img_shape[1], img_shape[2] + 1),
               padding="same"),
        Activation("relu"),
        Conv2D(64, kernel_size=3, strides=2, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        Conv2D(128, kernel_size=3, padding="same"),
        BatchNormalization(),
        Activation("relu"),
        Flatten(),
        Dense(1, activation="sigmoid"),
    ])
# Conditional discriminator: (image, label) -> real/fake probability.
def build_cgan_discriminator(img_shape):
    """Return a Model judging whether an image is real for a given label."""
    img_in = Input(shape=img_shape)
    label_in = Input(shape=(1,), dtype="int32")
    # Embed the label into one value per image pixel, then reshape it
    # into an image-shaped map so it can ride along as an extra channel.
    label_flat = Flatten()(
        Embedding(num_classes, np.prod(img_shape), input_length=1)(label_in))
    label_map = Reshape(img_shape)(label_flat)
    stacked = Concatenate(axis=-1)([img_in, label_map])
    verdict = build_discriminator(img_shape)(stacked)
    return Model([img_in, label_in], verdict)
# End-to-end CGAN: generator output piped straight into the discriminator.
def build_cgan(generator, discriminator):
    """Chain generator and discriminator into one combined Model.

    Used to train the generator while the discriminator is frozen.
    """
    noise_in = Input(shape=(z_dim,))
    label_in = Input(shape=(1,))
    fake_img = generator([noise_in, label_in])
    verdict = discriminator([fake_img, label_in])
    return Model([noise_in, label_in], verdict)
# print(build_cgan(build_cgan_generator(z_dim),build_cgan_discriminator(img_shape)).summary())
# Show a grid of generated images, one per digit class 0-9.
def sample_images(img_grid_row=2, img_grid_col=5):
    """Generate one image per digit (0-9) and display them in a grid.

    Uses the module-level `generator`. Assumes
    img_grid_row * img_grid_col == 10 so labels 0..9 fill the grid exactly.
    """
    z = np.random.normal(0, 1, (img_grid_col * img_grid_row, z_dim))
    labels = np.arange(0, 10).reshape(-1, 1)
    gen_images = generator.predict([z, labels])
    # Rescale the tanh output from [-1, 1] to [0, 1] for imshow.
    gen_images = 0.5 * gen_images + 0.5
    fig, axs = plt.subplots(img_grid_row, img_grid_col, figsize=(10, 4),
                            sharex=True, sharey=True)
    cnt = 0
    for i in range(img_grid_row):
        for j in range(img_grid_col):
            axs[i, j].imshow(gen_images[cnt, :, :, 0], cmap="gray")
            axs[i, j].axis("off")
            # BUG FIX: the original called axs.set_title(...) on the ndarray
            # of Axes (AttributeError); the title belongs on the subplot.
            # labels[cnt, 0] passes a scalar to %d instead of a (1,)-array.
            axs[i, j].set_title("Digit: %d" % labels[cnt, 0])
            cnt += 1
    plt.show()
# Build and compile the standalone discriminator (trained directly on
# real/fake batches, so it gets its own optimizer and accuracy metric).
discriminator=build_cgan_discriminator(img_shape)
discriminator.compile(loss="binary_crossentropy",optimizer=Adam(),metrics=["acc"])
generator=build_cgan_generator(z_dim)
# Freeze the discriminator BEFORE compiling the combined model, so that
# training `cgan` updates only the generator's weights.
discriminator.trainable=False
cgan=build_cgan(generator,discriminator)
cgan.compile(loss="binary_crossentropy",optimizer=Adam())
# Training history collected by train(): discriminator accuracy (%) and
# (d_loss, g_loss) pairs, appended every `sample_interval` iterations.
accuracies=[]
losses=[]
def train(iterations, batch_size, sample_interval):
    """Train the CGAN on MNIST.

    Args:
        iterations: number of training iterations (one batch each).
        batch_size: samples per discriminator/generator update.
        sample_interval: log metrics and plot sample images every this
            many iterations.

    Side effects: appends to the module-level `losses` and `accuracies`
    lists and displays sample grids via `sample_images()`.
    """
    (X_train, y_train), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh
    # output; add a channel axis: (N, 28, 28) -> (N, 28, 28, 1).
    X_train = X_train / 127.5 - 1.
    X_train = np.expand_dims(X_train, axis=3)
    real = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))
    for iteration in range(iterations):
        # --- Discriminator: one real batch, one generated batch. ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs, labels = X_train[idx], y_train[idx]
        z = np.random.normal(0, 1, (batch_size, z_dim))
        # BUG FIX: use predict() to obtain a numpy batch; the original
        # called the model directly, which yields a tensor, not an array.
        gen_img = generator.predict([z, labels])
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch([imgs, labels], real)
        d_loss_fake = discriminator.train_on_batch([gen_img, labels], fake)
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)
        discriminator.trainable = False
        # --- Generator: trained through the frozen discriminator. ---
        z = np.random.normal(0, 1, (batch_size, z_dim))
        labels = np.random.randint(0, num_classes, (batch_size)).reshape(-1, 1)
        # BUG FIX: the generator must be trained with *real* targets — it
        # learns to make the discriminator say "real" for generated images.
        # The original passed `fake`, which trains the generator backwards.
        g_loss = cgan.train_on_batch([z, labels], real)
        # BUG FIX: the original tested (iterations+1) — a constant — so
        # logging/sampling never ran; it must use the loop variable.
        if (iteration + 1) % sample_interval == 0:
            print("%d [D loss: %f,acc.: %.2f] [G loss: %f]"
                  % (iteration + 1, d_loss[0], d_loss[1] * 100, g_loss))
            losses.append((d_loss[0], g_loss))
            accuracies.append(d_loss[1] * 100)
            sample_images()
# Training configuration and entry call.
iterations = 12000
batch_size = 32
sample_interval = 1000
train(iterations=iterations,
      batch_size=batch_size,
      sample_interval=sample_interval)
# --- Blog-post footer (not code; commented out so the file parses) ---
# Keras implementation of a CGAN for handwritten-digit generation (MNIST).
# Originally published on CSDN (latest recommended article 2024-06-21).
# The post covers: generator/discriminator definitions, the conditional
# generator/discriminator wrappers, the combined CGAN, the training loop
# with loss computation, and sampling of generated digit images.