2. DeblurGAN
1.生成器
def generator_model():
    """Build the DeblurGAN generator.

    Architecture: reflection-padded 7x7 conv head, two stride-2 downsampling
    convs, n_blocks_gen residual blocks, two upsampling stages, a 7x7 conv
    tail with tanh, and a global residual connection from input to output.

    Returns:
        keras Model mapping (256, 256, 3) blurred images to deblurred images.
    """
    inputs = Input(shape=image_shape)  # (256, 256, 3)
    # Head: reflection pad then 7x7 'valid' conv keeps the spatial size.
    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2 ** i  # 1, 2
        x = Conv2D(filters=ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)  # ends at (64, 64, 256)
    mult = 2 ** n_downsampling  # 4
    for i in range(n_blocks_gen):  # 9 residual blocks
        x = res_block(x, ngf * mult, use_dropout=True)
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        x = UpSampling2D()(x)
        x = Conv2D(filters=int(ngf * mult / 2), kernel_size=(3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)  # ends at (256, 256, 64)
    # Tail: back to output_nc channels; tanh bounds values to [-1, 1].
    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)  # (256, 256, 3)
    x = Activation('tanh')(x)
    # Global residual connection: the network predicts a correction to the input.
    outputs = Add()([x, inputs])
    model = Model(inputs=inputs, outputs=outputs, name='generator')
    return model
2.判别器
def discriminator_model():
    """Build the DeblurGAN critic (discriminator).

    A stack of stride-2 4x4 convs with LeakyReLU, a 1-channel conv map, then
    dense layers emitting one sigmoid score. With use_sigmoid=False the conv
    map itself stays unbounded (WGAN-style critic).

    Returns:
        keras Model mapping (256, 256, 3) images to a scalar score.
    """
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)  # (256, 256, 3)
    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)  # (128, 128, 64)
    x = LeakyReLU(0.2)(x)
    for n in range(n_layers):
        ndf_mult = 2 ** n
        x = Conv2D(filters=ndf * ndf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    # ndf_mult is 4 after the loop, so this layer has 512 filters.
    x = Conv2D(filters=ndf * ndf_mult * 2, kernel_size=(4, 4), strides=1, padding='same')(x)  # (16, 16, 512)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)  # (16, 16, 1)
    if use_sigmoid:
        x = Activation('sigmoid')(x)
    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)  # single real/fake score
    model = Model(inputs=inputs, outputs=x, name='discriminator')
    return model
3.判别网络,训练判别器
# Critic network: trained on its own with the Wasserstein loss.
d=discriminator_model()
d_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
d.trainable=True
d.compile(optimizer=d_opt,loss=wasserstein_loss)
4.生成判别网络,训练生成器
def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    """Chain generator -> discriminator into one trainable graph.

    The model has two outputs so two losses can be applied: the generated
    image (perceptual loss) and the critic's score (Wasserstein loss).

    Args:
        generator: the deblurring generator model.
        discriminator: the critic model (typically frozen when this trains).

    Returns:
        keras Model mapping a blurred image to [generated_image, critic_score].
    """
    inputs = Input(shape=image_shape)
    generated_image = generator(inputs)
    outputs = discriminator(generated_image)
    model = Model(inputs=inputs, outputs=[generated_image, outputs])
    return model
# Combined generator+critic graph: the critic is frozen here so only the
# generator's weights update when this model trains.
d.trainable=False
g=generator_model()
d_on_g=generator_containing_discriminator_multiple_outputs(g,d)
d_on_g_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
loss=[preceptual_loss,wasserstein_loss]
loss_weights=[100,1]
d_on_g.compile(optimizer=d_on_g_opt,loss=loss,loss_weights=loss_weights)
d.trainable=True
5.全部代码
from keras.layers import Input,Conv2D,Dense,Flatten,Lambda,Dropout,BatchNormalization,LeakyReLU,Add,Activation,UpSampling2D
from keras.models import Model
from keras.utils import conv_utils
from keras.engine import InputSpec,Layer
import keras.backend as K
import tensorflow as tf
import os
import datetime
import click#快速创建命令行
import numpy as np
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
import tqdm#进度条
from PIL import Image
from keras.applications.vgg16 import VGG16
#镜像填充,好像torch有api
def spatial_reflection_2d_padding(x, padding=((1, 1), (1, 1)), data_format=None):
    """Reflection-pad the spatial dimensions of a 4-D tensor.

    Args:
        x: 4-D tensor (batch, H, W, C) or (batch, C, H, W).
        padding: ((top, bottom), (left, right)) pad amounts.
        data_format: 'channels_first' or 'channels_last'; defaults to the
            Keras image data format.

    Returns:
        The tensor padded with tf.pad in 'REFLECT' (mirror) mode.

    Raises:
        ValueError: if data_format is not a recognized value.
    """
    assert len(padding) == 2
    assert len(padding[0]) == 2
    assert len(padding[1]) == 2
    if data_format is None:
        data_format = K.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format: ' + str(data_format))
    if data_format == 'channels_first':
        # BUGFIX: original read [[0,0][0,0],...] (missing comma — an indexing
        # expression, not two lists). Batch and channel axes get no padding;
        # H and W get the requested amounts.
        pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
    else:
        # channels_last: pad H and W; batch and channel axes untouched.
        pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
    return tf.pad(x, pattern, 'REFLECT')  # mirror padding
#镜像填充做成一个Layer,自动调用call函数
class ReflectionPadding2D(Layer):
    """Keras layer applying reflection (mirror) padding to 4-D inputs.

    `padding` may be an int (same pad on all sides), a pair
    (symmetric_height_pad, symmetric_width_pad), or a pair of pairs
    ((top, bottom), (left, right)).
    """

    def __init__(self, padding=(1, 1), data_format=None, **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        if isinstance(padding, int):  # one number: identical pad everywhere
            self.padding = ((padding, padding), (padding, padding))
        elif hasattr(padding, "__len__"):  # (h, w) or ((t, b), (l, r))
            if len(padding) != 2:
                raise ValueError('padding should have two elements.found:' + str(padding))
            height_padding = conv_utils.normalize_tuple(padding[0], 2, '1st entry of padding')
            width_padding = conv_utils.normalize_tuple(padding[1], 2, '2nd entry of padding')
            self.padding = (height_padding, width_padding)
        else:
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        self.input_spec = InputSpec(ndim=4)  # restrict inputs to rank 4

    def compute_output_shape(self, input_shape):
        """Each spatial dimension grows by its two pad amounts."""
        if self.data_format == 'channels_first':
            if input_shape[2] is not None:
                rows = input_shape[2] + self.padding[0][0] + self.padding[0][1]  # H + top + bottom
            else:
                rows = None
            if input_shape[3] is not None:
                cols = input_shape[3] + self.padding[1][0] + self.padding[1][1]  # W + left + right
            else:
                cols = None
            return (input_shape[0], input_shape[1], rows, cols)
        elif self.data_format == 'channels_last':
            if input_shape[1] is not None:
                rows = input_shape[1] + self.padding[0][0] + self.padding[0][1]
            else:
                rows = None
            if input_shape[2] is not None:
                cols = input_shape[2] + self.padding[1][0] + self.padding[1][1]
            else:
                cols = None
            return (input_shape[0], rows, cols, input_shape[3])

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs, padding=self.padding, data_format=self.data_format)

    def get_config(self):
        # Serialize padding/data_format so the layer survives model save/load.
        config = {'padding': self.padding, 'data_format': self.data_format}
        base_config = super(ReflectionPadding2D, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
#残差卷积块
def res_block(inputs, filters, kernel_size=(3, 3), strides=(1, 1), use_dropout=False):
    """Residual block: two reflection-padded convs with BatchNorm.

    Args:
        inputs: 4-D feature tensor.
        filters: number of conv filters (must match inputs' channels so the
            residual Add is valid).
        kernel_size: conv kernel size, default 3x3.
        strides: conv strides, default 1 (spatial size preserved).
        use_dropout: if True, Dropout(0.5) after the first activation.

    Returns:
        inputs + conv branch output.
    """
    x = ReflectionPadding2D((1, 1))(inputs)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    if use_dropout:
        x = Dropout(0.5)(x)
    x = ReflectionPadding2D((1, 1))(x)
    x = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(x)
    x = BatchNormalization()(x)
    merged = Add()([inputs, x])  # skip connection
    return merged
channel_rate=64
image_shape=(256,256,3)
patch_shape=(channel_rate,channel_rate,3)
ngf=64# generator filter counts are integer multiples of ngf
ndf=64# discriminator filter counts are integer multiples of ndf
input_nc=3# input image channels
output_nc=3# output image channels
input_shape_generator=(256,256,input_nc)# generator input shape
input_shape_discriminator=(256,256,output_nc)# discriminator input shape
n_blocks_gen=9# number of residual blocks in the generator
def generator_model():
    """Build the DeblurGAN generator.

    Architecture: reflection-padded 7x7 conv head, two stride-2 downsampling
    convs, n_blocks_gen residual blocks, two upsampling stages, a 7x7 conv
    tail with tanh, and a global residual connection from input to output.

    Returns:
        keras Model mapping (256, 256, 3) blurred images to deblurred images.
    """
    inputs = Input(shape=image_shape)  # (256, 256, 3)
    # Head: reflection pad then 7x7 'valid' conv keeps the spatial size.
    x = ReflectionPadding2D((3, 3))(inputs)
    x = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(x)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    n_downsampling = 2
    for i in range(n_downsampling):
        mult = 2 ** i  # 1, 2
        x = Conv2D(filters=ngf * mult * 2, kernel_size=(3, 3), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)  # ends at (64, 64, 256)
    mult = 2 ** n_downsampling  # 4
    for i in range(n_blocks_gen):  # 9 residual blocks
        x = res_block(x, ngf * mult, use_dropout=True)
    for i in range(n_downsampling):
        mult = 2 ** (n_downsampling - i)
        x = UpSampling2D()(x)
        x = Conv2D(filters=int(ngf * mult / 2), kernel_size=(3, 3), padding='same')(x)
        x = BatchNormalization()(x)
        x = Activation('relu')(x)  # ends at (256, 256, 64)
    # Tail: back to output_nc channels; tanh bounds values to [-1, 1].
    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(x)  # (256, 256, 3)
    x = Activation('tanh')(x)
    # Global residual connection: the network predicts a correction to the input.
    outputs = Add()([x, inputs])
    model = Model(inputs=inputs, outputs=outputs, name='generator')
    return model
def discriminator_model():
    """Build the DeblurGAN critic (discriminator).

    A stack of stride-2 4x4 convs with LeakyReLU, a 1-channel conv map, then
    dense layers emitting one sigmoid score. With use_sigmoid=False the conv
    map itself stays unbounded (WGAN-style critic).

    Returns:
        keras Model mapping (256, 256, 3) images to a scalar score.
    """
    n_layers, use_sigmoid = 3, False
    inputs = Input(shape=input_shape_discriminator)  # (256, 256, 3)
    x = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(inputs)  # (128, 128, 64)
    x = LeakyReLU(0.2)(x)
    for n in range(n_layers):
        ndf_mult = 2 ** n
        x = Conv2D(filters=ndf * ndf_mult, kernel_size=(4, 4), strides=2, padding='same')(x)
        x = BatchNormalization()(x)
        x = LeakyReLU(0.2)(x)
    # ndf_mult is 4 after the loop, so this layer has 512 filters.
    x = Conv2D(filters=ndf * ndf_mult * 2, kernel_size=(4, 4), strides=1, padding='same')(x)  # (16, 16, 512)
    x = BatchNormalization()(x)
    x = LeakyReLU(0.2)(x)
    x = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(x)  # (16, 16, 1)
    if use_sigmoid:
        x = Activation('sigmoid')(x)
    x = Flatten()(x)
    x = Dense(1024, activation='tanh')(x)
    x = Dense(1, activation='sigmoid')(x)  # single real/fake score
    model = Model(inputs=inputs, outputs=x, name='discriminator')
    return model
def generator_containing_discriminator_multiple_outputs(generator, discriminator):
    """Chain generator -> discriminator into one trainable graph.

    The model has two outputs so two losses can be applied: the generated
    image (perceptual loss) and the critic's score (Wasserstein loss).

    Args:
        generator: the deblurring generator model.
        discriminator: the critic model (typically frozen when this trains).

    Returns:
        keras Model mapping a blurred image to [generated_image, critic_score].
    """
    inputs = Input(shape=image_shape)
    generated_image = generator(inputs)
    outputs = discriminator(generated_image)
    model = Model(inputs=inputs, outputs=[generated_image, outputs])
    return model
#loss
def preceptual_loss(y_true, y_pred):
    """Perceptual loss: MSE between frozen VGG16 'block3_conv3' features of
    the target and the generated image.

    NOTE(review): the name is a typo for "perceptual_loss" but is kept
    because the compile calls reference it; also, VGG16 is rebuilt on every
    invocation — hoisting it to module level would avoid repeated loads.
    """
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    loss_model = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    loss_model.trainable = False  # feature extractor only; VGG is never trained
    return K.mean(K.square(loss_model(y_true) - loss_model(y_pred)))
def wasserstein_loss(y_true, y_pred):
    """Wasserstein critic loss: mean(label * score), labels are +1 / -1."""
    return K.mean(y_true * y_pred)
#save
BASE_DIR='weights/'
def save_all_weights(d, g, epoch_number, current_loss):
    """Checkpoint generator and discriminator weights.

    Files go under BASE_DIR/<month><day>/ (created on demand):
    generator_<epoch>_<loss>.h5 and discriminator_<epoch>.h5, with
    overwrite=True.
    """
    now = datetime.datetime.now()
    save_dir = os.path.join(BASE_DIR, '{}{}'.format(now.month, now.day))
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    g.save_weights(os.path.join(save_dir, 'generator_{}_{}.h5'.format(epoch_number, current_loss)), True)
    d.save_weights(os.path.join(save_dir, 'discriminator_{}.h5'.format(epoch_number)), True)
#load_images, write_log
# Target spatial size for all training images.
RESHAPE = (256, 256)

def preprocess_image(cv_img):
    """Resize a PIL-style image to RESHAPE and scale pixels to [-1, 1].

    Args:
        cv_img: object exposing a PIL-like resize(size) method.

    Returns:
        numpy array normalized from [0, 255] to [-1, 1] (matches the
        generator's tanh output range).
    """
    cv_img = cv_img.resize(RESHAPE)
    img = np.array(cv_img)
    img = (img - 127.5) / 127.5  # [0, 255] -> [-1, 1]
    return img
def load_images(path, n_images):
    """Load up to n_images (blurred, sharp) image pairs from disk.

    Expects path/blurred and path/sharp to contain correspondingly named
    files. Pass n_images < 0 to load everything.

    Returns:
        dict with 'A'/'B' image arrays (blurred/sharp, preprocessed to
        [-1, 1]) and 'A_paths'/'B_paths' file-path arrays.
    """
    if n_images < 0:
        n_images = float('inf')
    A_path, B_path = os.path.join(path, 'blurred'), os.path.join(path, 'sharp')
    # BUGFIX: os.listdir order is arbitrary; sort both listings so zip()
    # pairs each blurred image with its matching sharp image.
    all_A_paths = [os.path.join(A_path, f) for f in sorted(os.listdir(A_path))]
    all_B_paths = [os.path.join(B_path, f) for f in sorted(os.listdir(B_path))]
    images_A, images_B = [], []
    images_A_paths, images_B_paths = [], []
    for path_A, path_B in zip(all_A_paths, all_B_paths):
        img_A = Image.open(path_A)
        img_B = Image.open(path_B)
        images_A.append(preprocess_image(img_A))
        images_B.append(preprocess_image(img_B))
        images_A_paths.append(path_A)
        images_B_paths.append(path_B)
        if len(images_A) > n_images - 1:  # stop once n_images pairs are loaded
            break
    return {
        'A': np.array(images_A),
        'A_paths': np.array(images_A_paths),
        'B': np.array(images_B),
        'B_paths': np.array(images_B_paths)
    }
def write_log(callback, names, logs, batch_no):
    """Write named scalar values to TensorBoard via a Keras TensorBoard
    callback's underlying tf summary writer.

    Args:
        callback: TensorBoard callback exposing a `.writer`.
        names: tag names, zipped with `logs`.
        logs: scalar values to record.
        batch_no: global step for the summaries.
    """
    for name, value in zip(names, logs):
        summary = tf.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        callback.writer.add_summary(summary, batch_no)
        callback.writer.flush()
#train
def train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates=5):
    """Train DeblurGAN: per batch, update the critic `critic_updates` times,
    then take one combined generator step; log and checkpoint each epoch.

    Args:
        n_images: number of image pairs to load (< 0 for all).
        batch_size: mini-batch size.
        epoch_num: number of training epochs.
        critic_updates: critic steps per generator step (WGAN practice).
    """
    data = load_images('./blurred_sharp', n_images)
    y_train, x_train = data['B'], data['A']  # sharp targets, blurred inputs
    # Critic, compiled alone with the Wasserstein loss.
    d = discriminator_model()
    d_opt = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    d.trainable = True
    d.compile(optimizer=d_opt, loss=wasserstein_loss)
    # Combined graph: critic frozen at compile time so only the generator
    # updates when d_on_g trains.
    d.trainable = False
    g = generator_model()
    d_on_g = generator_containing_discriminator_multiple_outputs(g, d)
    d_on_g_opt = Adam(lr=1e-4, beta_1=0.9, beta_2=0.999, epsilon=1e-08)
    loss = [preceptual_loss, wasserstein_loss]
    loss_weights = [100, 1]
    d_on_g.compile(optimizer=d_on_g_opt, loss=loss, loss_weights=loss_weights)
    d.trainable = True
    # Wasserstein labels: +1 for real batches, -1 for generated ones.
    output_true_batch, output_false_batch = np.ones((batch_size, 1)), -np.ones((batch_size, 1))
    log_path = './logs'
    tensorboard_callback = TensorBoard(log_path)  # NOTE(review): created but unused below
    for epoch in tqdm.tqdm(range(epoch_num)):
        permutated_indexs = np.random.permutation(x_train.shape[0])  # reshuffle each epoch
        d_losses = []
        d_on_g_losses = []
        for index in range(int(x_train.shape[0] / batch_size)):
            batch_indexes = permutated_indexs[index * batch_size:(index + 1) * batch_size]
            image_blur_batch = x_train[batch_indexes]
            image_full_batch = y_train[batch_indexes]
            generated_images = g.predict(x=image_blur_batch, batch_size=batch_size)
            # Multiple critic updates per generator step stabilize WGAN training.
            for _ in range(critic_updates):
                d_loss_real = d.train_on_batch(image_full_batch, output_true_batch)
                d_loss_fake = d.train_on_batch(generated_images, output_false_batch)
                d_loss = 0.5 * np.add(d_loss_fake, d_loss_real)
                d_losses.append(d_loss)
            # Freeze the critic while the generator trains through it.
            d.trainable = False
            d_on_g_loss = d_on_g.train_on_batch(image_blur_batch, [image_full_batch, output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            d.trainable = True
        print(np.mean(d_losses), np.mean(d_on_g_losses))  # epoch-mean losses
        with open('log.txt', 'a+') as f:
            f.write('{}-{}-{}\n'.format(epoch, np.mean(d_losses), np.mean(d_on_g_losses)))
        save_all_weights(d, g, epoch, int(np.mean(d_on_g_losses)))
#@click.command()
#@click.option('--n_images', default=10, help='Number of images to load for training')
#@click.option('--batch_size', default=1, help='Size of batch')
#@click.option('--log_dir', required=True, help='Path to the log_dir for Tensorboard')
#@click.option('--epoch_num', default=1, help='Number of epochs for training')
#@click.option('--critic_updates', default=1, help='Number of discriminator training')
def train_command(n_images, batch_size, epoch_num, critic_updates):
    """Entry point for training (the click CLI decorators are commented out
    in the source; this currently runs as a plain function)."""
    return train_multiple_outputs(n_images, batch_size, epoch_num, critic_updates)
if __name__ == '__main__':
    # Smoke-run: 10 images, batch size 1, 1 epoch, 1 critic update per step.
    train_command(10, 1, 1, 1)
6.参考
[1]https://github.com/RaphaelMeudec/deblur-gan