2. DeblurGAN

1.生成器

def generator_model():
    """Build the DeblurGAN generator: stem -> encoder -> residual bottleneck -> decoder -> head.

    The input has shape ``image_shape``.  The final tanh branch is added back
    onto the input image (a global skip connection), so the network predicts a
    residual correction of the blurred image.

    Reads the module-level settings ``image_shape``, ``ngf``, ``n_blocks_gen``
    and ``output_nc``.

    NOTE(review): the tanh branch lies in [-1, 1] and preprocess_image scales
    inputs to [-1, 1], so the summed output can span [-2, 2]; nothing here
    clips or rescales it.
    """
    n_downsampling = 2
    blurred = Input(shape=image_shape)

    # Stem: a 7x7 'valid' conv on a 3-pixel reflection-padded input keeps H and W.
    h = ReflectionPadding2D((3, 3))(blurred)
    h = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(h)
    h = BatchNormalization()(h)
    h = Activation('relu')(h)

    # Encoder: every stride-2 conv halves H and W and doubles the filter count
    # (ngf*2, then ngf*4; with the defaults the result is 64x64x256).
    for level in range(1, n_downsampling + 1):
        h = Conv2D(filters=ngf * 2 ** level, kernel_size=(3, 3), strides=2, padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)

    # Bottleneck: n_blocks_gen residual blocks at the deepest width.
    for _ in range(n_blocks_gen):
        h = res_block(h, ngf * 2 ** n_downsampling, use_dropout=True)

    # Decoder: upsample x2 then conv, halving the filters each step
    # (ngf*2, then ngf; with the defaults the result is 256x256x64).
    for level in range(n_downsampling - 1, -1, -1):
        h = UpSampling2D()(h)
        h = Conv2D(filters=int(ngf * 2 ** level), kernel_size=(3, 3), padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)

    # Head: 7x7 conv back to output_nc channels, squashed by tanh.
    h = ReflectionPadding2D((3, 3))(h)
    h = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(h)
    h = Activation('tanh')(h)

    # Global residual connection: correction + blurred input.
    restored = Add()([h, blurred])
    return Model(inputs=blurred, outputs=restored, name='generator')

2.判别器

def discriminator_model():
    """Build the DeblurGAN critic, mapping an image to one score.

    The input shape is ``input_shape_discriminator``.  With the default
    constants the feature maps go 256x256x3 -> 128x128x64 -> 64x64x64 ->
    32x32x128 -> 16x16x256 -> 16x16x512 -> 16x16x1.  That map is flattened and
    passed through Dense(1024, tanh) and Dense(1, sigmoid).

    NOTE(review): the final sigmoid bounds the score to (0, 1) even though
    ``use_sigmoid`` is False and the training labels are +1/-1 for the
    Wasserstein loss; confirm this is intended.
    """
    n_layers, use_sigmoid = 3, False
    image = Input(shape=input_shape_discriminator)

    # The first stride-2 conv has no batch norm.
    h = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(image)
    h = LeakyReLU(0.2)(h)

    # n_layers stride-2 convs with filter multipliers 1, 2, 4.
    multipliers = [2 ** n for n in range(n_layers)]
    for mult in multipliers:
        h = Conv2D(filters=ndf * mult, kernel_size=(4, 4), strides=2, padding='same')(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(0.2)(h)

    # Stride-1 conv doubling the last width (ndf*8 = 512 with defaults).
    h = Conv2D(filters=ndf * multipliers[-1] * 2, kernel_size=(4, 4), strides=1, padding='same')(h)
    h = BatchNormalization()(h)
    h = LeakyReLU(0.2)(h)

    # Single-channel patch map.
    h = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(h)
    if use_sigmoid:
        h = Activation('sigmoid')(h)

    # Collapse the patch map into one scalar score.
    h = Flatten()(h)
    h = Dense(1024, activation='tanh')(h)
    score = Dense(1, activation='sigmoid')(h)
    return Model(inputs=image, outputs=score, name='discriminator')

3.判别网络,训练判别器

 # Critic (discriminator) network -- excerpt from train_multiple_outputs.
    # The critic is compiled with the Wasserstein loss while trainable=True,
    # so d.train_on_batch(...) updates the critic's own weights.
    d=discriminator_model()
    d_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
    d.trainable=True
    d.compile(optimizer=d_opt,loss=wasserstein_loss)

4.生成判别网络,训练生成器

def generator_containing_discriminator_multiple_outputs(generator,discriminator):
    """Chain generator -> discriminator into one model used to train the generator.

    Args:
        generator: model mapping a blurred image to a restored image.
        discriminator: model mapping an image to a critic score.

    Returns:
        A Model with input shape ``image_shape`` and two outputs:
        [restored image, critic score of that restored image].
    """
    blurred = Input(shape=image_shape)
    restored = generator(blurred)
    score = discriminator(restored)
    return Model(inputs=blurred, outputs=[restored, score])
# Combined generator -> critic model (excerpt from train_multiple_outputs).
# The critic is set non-trainable while d_on_g is compiled, so d_on_g's
# updates are meant to change only the generator; it is switched back to
# trainable afterwards for the critic's own train_on_batch calls.
d.trainable=False
g=generator_model()
d_on_g=generator_containing_discriminator_multiple_outputs(g,d)
d_on_g_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
# One loss per d_on_g output: perceptual loss on the restored image and
# Wasserstein loss on the critic score; the perceptual term is weighted 100x.
loss=[preceptual_loss,wasserstein_loss]
loss_weights=[100,1]
d_on_g.compile(optimizer=d_on_g_opt,loss=loss,loss_weights=loss_weights)
d.trainable=True

5.全部代码

from keras.layers import Input,Conv2D,Dense,Flatten,Lambda,Dropout,BatchNormalization,LeakyReLU,Add,Activation,UpSampling2D
from keras.models import Model
from keras.utils import conv_utils
from keras.engine import InputSpec,Layer
import keras.backend as K
import tensorflow as tf
import os
import datetime
import click#快速创建命令行
import numpy as np
from keras.callbacks import TensorBoard
from keras.optimizers import Adam
import tqdm#进度条
from PIL import Image
from keras.applications.vgg16 import VGG16
# Reflection ("mirror") padding for 4-D image tensors, implemented with tf.pad.
def spatial_reflection_2d_padding(x,padding=((1,1),(1,1)),data_format=None):
    """Reflect-pad the two spatial axes of a 4-D tensor.

    Args:
        x: 4-D tensor, NCHW for 'channels_first' or NHWC for 'channels_last'.
        padding: ((top_pad, bottom_pad), (left_pad, right_pad)).
        data_format: 'channels_first', 'channels_last', or None to use
            K.image_data_format().

    Returns:
        The padded tensor; the batch and channel axes are not padded.

    Raises:
        ValueError: if ``padding`` is not a pair of pairs or ``data_format``
            is unknown.

    Note: tf.pad's REFLECT mode requires each pad to be smaller than the
    size of the axis being padded.
    """
    if len(padding) != 2 or len(padding[0]) != 2 or len(padding[1]) != 2:
        raise ValueError('padding must be ((top, bottom), (left, right)); found: ' + str(padding))
    if data_format is None:
        data_format = K.image_data_format()
    if data_format not in {'channels_first', 'channels_last'}:
        raise ValueError('Unknown data_format: ' + str(data_format))
    if data_format == 'channels_first':
        # NCHW: no padding on batch or channel axes, pad H then W.
        # (The original listing was missing the comma between the first two
        # pairs, which indexed a list with a tuple and raised TypeError.)
        pattern = [[0, 0], [0, 0], list(padding[0]), list(padding[1])]
    else:
        # NHWC: no padding on batch or channel axes, pad H then W.
        pattern = [[0, 0], list(padding[0]), list(padding[1]), [0, 0]]
    return tf.pad(x, pattern, 'REFLECT')
# Keras layer wrapper around spatial_reflection_2d_padding (Keras invokes call()).
class ReflectionPadding2D(Layer):
    """Keras layer that reflect-pads the height and width axes of a 4-D input.

    Args:
        padding: an int (same pad on every side), a pair
            (symmetric_height_pad, symmetric_width_pad), or a pair of pairs
            ((top_pad, bottom_pad), (left_pad, right_pad)).
        data_format: 'channels_first', 'channels_last' or None; normalized with
            conv_utils.normalize_data_format.
        **kwargs: forwarded to the base Layer.

    Raises:
        ValueError: for a padding spec that is none of the accepted forms.
    """

    def __init__(self, padding=(1, 1), data_format=None, **kwargs):
        super(ReflectionPadding2D, self).__init__(**kwargs)
        self.data_format = conv_utils.normalize_data_format(data_format)
        self.padding = self._to_pad_pairs(padding)
        # Only rank-4 inputs are accepted.
        self.input_spec = InputSpec(ndim=4)

    @staticmethod
    def _to_pad_pairs(padding):
        """Normalize any accepted padding spec to ((top, bottom), (left, right))."""
        if isinstance(padding, int):
            return ((padding, padding), (padding, padding))
        if not hasattr(padding, "__len__"):
            raise ValueError('`padding` should be either an int, '
                             'a tuple of 2 ints '
                             '(symmetric_height_pad, symmetric_width_pad), '
                             'or a tuple of 2 tuples of 2 ints '
                             '((top_pad, bottom_pad), (left_pad, right_pad)). '
                             'Found: ' + str(padding))
        if len(padding) != 2:
            raise ValueError('padding should have two elements.found:' + str(padding))
        return (conv_utils.normalize_tuple(padding[0], 2, '1st entry of padding'),
                conv_utils.normalize_tuple(padding[1], 2, '2nd entry of padding'))

    def compute_output_shape(self, input_shape):
        """Keras shape-inference hook: H and W grow by their pads; unknown (None) stays None."""
        def grow(size, pads):
            return None if size is None else size + pads[0] + pads[1]

        if self.data_format == 'channels_first':
            return (input_shape[0], input_shape[1],
                    grow(input_shape[2], self.padding[0]),
                    grow(input_shape[3], self.padding[1]))
        if self.data_format == 'channels_last':
            return (input_shape[0],
                    grow(input_shape[1], self.padding[0]),
                    grow(input_shape[2], self.padding[1]),
                    input_shape[3])
        return None

    def call(self, inputs):
        return spatial_reflection_2d_padding(inputs, padding=self.padding, data_format=self.data_format)

    def get_config(self):
        merged = dict(super(ReflectionPadding2D, self).get_config())
        merged.update({'padding': self.padding, 'data_format': self.data_format})
        return merged
# Residual convolution block.
def res_block(inputs,filters,kernel_size=(3,3),strides=(1,1),use_dropout=False):
    """Two (reflect-pad -> conv -> batch-norm) stages plus an identity shortcut.

    A ReLU (and, if ``use_dropout``, Dropout(0.5)) follows the first stage only.
    The closing Add requires the branch output to match ``inputs`` exactly.
    That holds for the defaults (3x3 kernel, stride 1, 1-pixel reflection pad)
    when ``filters`` equals the input channel count.
    """
    branch = inputs
    for stage in (1, 2):
        branch = ReflectionPadding2D((1, 1))(branch)
        branch = Conv2D(filters=filters, kernel_size=kernel_size, strides=strides)(branch)
        branch = BatchNormalization()(branch)
        if stage == 1:
            branch = Activation('relu')(branch)
            if use_dropout:
                branch = Dropout(0.5)(branch)
    return Add()([inputs, branch])
# Global hyper-parameters shared by the model builders.
channel_rate = 64
image_shape = (256, 256, 3)                                # (height, width, channels) of every image
patch_shape = (channel_rate,) * 2 + (3,)                   # not referenced by the code shown
ngf = 64                                                   # generator filter counts are multiples of ngf
ndf = 64                                                   # discriminator filter counts are multiples of ndf
input_nc = 3                                               # channels of the input image
output_nc = 3                                              # channels of the output image
input_shape_generator = image_shape[:2] + (input_nc,)      # generator input shape (generator_model uses image_shape)
input_shape_discriminator = image_shape[:2] + (output_nc,) # discriminator input shape
n_blocks_gen = 9                                           # residual blocks in the generator bottleneck
def generator_model():
    """Build the DeblurGAN generator: stem -> encoder -> residual bottleneck -> decoder -> head.

    The input has shape ``image_shape``.  The final tanh branch is added back
    onto the input image (a global skip connection), so the network predicts a
    residual correction of the blurred image.

    Reads the module-level settings ``image_shape``, ``ngf``, ``n_blocks_gen``
    and ``output_nc``.

    NOTE(review): the tanh branch lies in [-1, 1] and preprocess_image scales
    inputs to [-1, 1], so the summed output can span [-2, 2]; nothing here
    clips or rescales it.
    """
    n_downsampling = 2
    blurred = Input(shape=image_shape)

    # Stem: a 7x7 'valid' conv on a 3-pixel reflection-padded input keeps H and W.
    h = ReflectionPadding2D((3, 3))(blurred)
    h = Conv2D(filters=ngf, kernel_size=(7, 7), padding='valid')(h)
    h = BatchNormalization()(h)
    h = Activation('relu')(h)

    # Encoder: every stride-2 conv halves H and W and doubles the filter count
    # (ngf*2, then ngf*4; with the defaults the result is 64x64x256).
    for level in range(1, n_downsampling + 1):
        h = Conv2D(filters=ngf * 2 ** level, kernel_size=(3, 3), strides=2, padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)

    # Bottleneck: n_blocks_gen residual blocks at the deepest width.
    for _ in range(n_blocks_gen):
        h = res_block(h, ngf * 2 ** n_downsampling, use_dropout=True)

    # Decoder: upsample x2 then conv, halving the filters each step
    # (ngf*2, then ngf; with the defaults the result is 256x256x64).
    for level in range(n_downsampling - 1, -1, -1):
        h = UpSampling2D()(h)
        h = Conv2D(filters=int(ngf * 2 ** level), kernel_size=(3, 3), padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)

    # Head: 7x7 conv back to output_nc channels, squashed by tanh.
    h = ReflectionPadding2D((3, 3))(h)
    h = Conv2D(filters=output_nc, kernel_size=(7, 7), padding='valid')(h)
    h = Activation('tanh')(h)

    # Global residual connection: correction + blurred input.
    restored = Add()([h, blurred])
    return Model(inputs=blurred, outputs=restored, name='generator')
def discriminator_model():
    """Build the DeblurGAN critic, mapping an image to one score.

    The input shape is ``input_shape_discriminator``.  With the default
    constants the feature maps go 256x256x3 -> 128x128x64 -> 64x64x64 ->
    32x32x128 -> 16x16x256 -> 16x16x512 -> 16x16x1.  That map is flattened and
    passed through Dense(1024, tanh) and Dense(1, sigmoid).

    NOTE(review): the final sigmoid bounds the score to (0, 1) even though
    ``use_sigmoid`` is False and the training labels are +1/-1 for the
    Wasserstein loss; confirm this is intended.
    """
    n_layers, use_sigmoid = 3, False
    image = Input(shape=input_shape_discriminator)

    # The first stride-2 conv has no batch norm.
    h = Conv2D(filters=ndf, kernel_size=(4, 4), strides=2, padding='same')(image)
    h = LeakyReLU(0.2)(h)

    # n_layers stride-2 convs with filter multipliers 1, 2, 4.
    multipliers = [2 ** n for n in range(n_layers)]
    for mult in multipliers:
        h = Conv2D(filters=ndf * mult, kernel_size=(4, 4), strides=2, padding='same')(h)
        h = BatchNormalization()(h)
        h = LeakyReLU(0.2)(h)

    # Stride-1 conv doubling the last width (ndf*8 = 512 with defaults).
    h = Conv2D(filters=ndf * multipliers[-1] * 2, kernel_size=(4, 4), strides=1, padding='same')(h)
    h = BatchNormalization()(h)
    h = LeakyReLU(0.2)(h)

    # Single-channel patch map.
    h = Conv2D(filters=1, kernel_size=(4, 4), strides=1, padding='same')(h)
    if use_sigmoid:
        h = Activation('sigmoid')(h)

    # Collapse the patch map into one scalar score.
    h = Flatten()(h)
    h = Dense(1024, activation='tanh')(h)
    score = Dense(1, activation='sigmoid')(h)
    return Model(inputs=image, outputs=score, name='discriminator')
def generator_containing_discriminator_multiple_outputs(generator,discriminator):
    """Chain generator -> discriminator into one model used to train the generator.

    Args:
        generator: model mapping a blurred image to a restored image.
        discriminator: model mapping an image to a critic score.

    Returns:
        A Model with input shape ``image_shape`` and two outputs:
        [restored image, critic score of that restored image].
    """
    blurred = Input(shape=image_shape)
    restored = generator(blurred)
    score = discriminator(restored)
    return Model(inputs=blurred, outputs=[restored, score])
# Loss functions.
def preceptual_loss(y_true,y_pred):
    """Perceptual loss: mean squared difference of VGG16 'block3_conv3' features.

    (The misspelled name is kept because the training code refers to it.)

    A new ImageNet-weighted VGG16 (no top, input shape ``image_shape``) is built
    on every call and wrapped in a non-trainable feature extractor.  Both
    tensors are passed through it and the squared feature difference is averaged.

    NOTE(review): VGG16 is normally fed ImageNet-preprocessed pixels, while the
    images here are scaled to [-1, 1] by preprocess_image; confirm this is
    acceptable for the perceptual features.
    """
    vgg = VGG16(include_top=False, weights='imagenet', input_shape=image_shape)
    feature_extractor = Model(inputs=vgg.input, outputs=vgg.get_layer('block3_conv3').output)
    # VGG is only a fixed feature extractor; it is never trained.
    feature_extractor.trainable = False
    feature_gap = feature_extractor(y_true) - feature_extractor(y_pred)
    return K.mean(K.square(feature_gap))
def wasserstein_loss(y_true,y_pred):
    """Wasserstein critic loss: the mean of label * score.

    The training loop labels real images +1 and generated images -1.  So
    minimizing this lowers the critic's score on real images and raises it on
    generated ones.  The generator is trained with label +1 on its outputs,
    i.e. toward the score real images get.
    """
    signed_scores = y_true * y_pred
    return K.mean(signed_scores)
# Checkpointing: weights are written under BASE_DIR/<MMDD>/ for the current day.
BASE_DIR='weights/'
def save_all_weights(d,g,epoch_number,current_loss):
    """Save generator and discriminator weights for one epoch.

    Files go to BASE_DIR/<MMDD>/ (zero-padded month and day of "now"):
    ``generator_<epoch>_<loss>.h5`` and ``discriminator_<epoch>.h5``.
    Existing files are overwritten.

    Args:
        d: discriminator model (anything with ``save_weights(path, overwrite)``).
        g: generator model.
        epoch_number: epoch index embedded in both file names.
        current_loss: value embedded in the generator file name.
    """
    now=datetime.datetime.now()
    # Zero-pad month and day: the old unpadded '{}{}' mapped e.g. Jan 11 and
    # Nov 1 both to '111', mixing checkpoints from different days.
    save_dir=os.path.join(BASE_DIR,'{:02d}{:02d}'.format(now.month,now.day))
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_dir,exist_ok=True)
    # The second positional argument of Keras save_weights is overwrite=True.
    g.save_weights(os.path.join(save_dir,'generator_{}_{}.h5'.format(epoch_number,current_loss)),True)
    d.save_weights(os.path.join(save_dir,'discriminator_{}.h5'.format(epoch_number)),True)
# Data loading helpers (load_images, write_log).
RESHAPE = (256, 256)  # target size passed to the image's resize()


def preprocess_image(cv_img):
    """Resize an image to RESHAPE and map pixel values from [0, 255] to [-1, 1].

    Despite the name, load_images passes a PIL image (from Image.open).  Any
    object whose ``resize(size)`` result converts with np.array works.

    Returns:
        A float numpy array of the resized pixels, scaled by (p - 127.5) / 127.5.
    """
    pixels = np.array(cv_img.resize(RESHAPE))
    return (pixels - 127.5) / 127.5
def load_images(path,n_images):
    """Load paired blurred/sharp images from ``path/blurred`` and ``path/sharp``.

    Files are paired by sorted file name.  Each blurred image must therefore
    have a sharp counterpart at the same sorted position (normally the same
    name).  zip() stops at the shorter listing.

    Args:
        path: dataset root directory.
        n_images: maximum number of pairs to load; a negative value loads all.

    Returns:
        dict with 'A' (blurred images) and 'B' (sharp images) as numpy arrays of
        preprocess_image outputs, plus 'A_paths' / 'B_paths' with the file
        paths in the same order.
    """
    if n_images<0:
        n_images=float('inf')
    A_path,B_path=os.path.join(path,'blurred'),os.path.join(path,'sharp')
    # os.listdir returns entries in arbitrary order; sort both listings so
    # zip() pairs the blurred and sharp versions of the same image.
    all_A_paths=[os.path.join(A_path,f) for f in sorted(os.listdir(A_path))]
    all_B_paths=[os.path.join(B_path,f) for f in sorted(os.listdir(B_path))]
    images_A,images_B=[],[]
    images_A_paths,images_B_paths=[],[]
    for path_A,path_B in zip(all_A_paths,all_B_paths):
        # Check the limit before loading, so n_images=0 loads nothing (the old
        # check after appending always loaded at least one pair).
        if len(images_A)>=n_images:
            break
        # Close both files once their pixels have been copied into arrays.
        with Image.open(path_A) as img_A,Image.open(path_B) as img_B:
            images_A.append(preprocess_image(img_A))
            images_B.append(preprocess_image(img_B))
        images_A_paths.append(path_A)
        images_B_paths.append(path_B)
    return {
        'A': np.array(images_A),
        'A_paths': np.array(images_A_paths),
        'B': np.array(images_B),
        'B_paths': np.array(images_B_paths)
    }



def write_log(callback, names, logs, batch_no):
    """Write one scalar summary per (name, value) pair through ``callback.writer``.

    Args:
        callback: object exposing ``writer`` with add_summary/flush
            (presumably a Keras TensorBoard callback -- confirm at call site).
        names: summary tags, paired with ``logs`` via zip().
        logs: scalar values to record.
        batch_no: step recorded with every summary.

    NOTE(review): tf.Summary is a TensorFlow 1.x API, and this helper is not
    called by the training code shown.
    """
    for tag, scalar in zip(names, logs):
        summary = tf.Summary()
        entry = summary.value.add()
        entry.simple_value = scalar
        entry.tag = tag
        callback.writer.add_summary(summary, batch_no)
        callback.writer.flush()
# Training loop: alternate critic updates and generator updates.
def train_multiple_outputs(n_images,batch_size,epoch_num,critic_updates=5):
    """Train DeblurGAN on the image pairs under ./blurred_sharp.

    For every batch:

    * The critic ``d`` is trained ``critic_updates`` times, on real sharp
      images (label +1) and on generated images (label -1).
    * The generator is then trained once through ``d_on_g``.  Its loss is
      100 * perceptual(restored, sharp) + 1 * wasserstein(score, +1).

    After every epoch the mean losses are printed, appended to log.txt and
    the weights are saved with save_all_weights.

    Args:
        n_images: number of image pairs to load (negative = all).
        batch_size: images per batch.
        epoch_num: number of epochs.
        critic_updates: critic steps per generator step.

    Raises:
        ValueError: if fewer than ``batch_size`` pairs were loaded, so not even
            one full batch exists.  The old code would build every model and
            then crash on int(np.mean([])).
    """
    data=load_images('./blurred_sharp',n_images)
    y_train,x_train=data['B'],data['A']  # sharp targets, blurred inputs
    n_batches=x_train.shape[0]//batch_size
    if n_batches==0:
        raise ValueError('need at least {} image pairs for one batch, found {}'.format(
            batch_size,x_train.shape[0]))

    # Critic: compiled while trainable so d.train_on_batch updates its weights.
    d=discriminator_model()
    d_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
    d.trainable=True
    d.compile(optimizer=d_opt,loss=wasserstein_loss)

    # Generator trained through the critic: d is frozen while d_on_g is
    # compiled so d_on_g's updates are meant to touch only the generator.
    d.trainable=False
    g=generator_model()
    d_on_g=generator_containing_discriminator_multiple_outputs(g,d)
    d_on_g_opt=Adam(lr=1e-4,beta_1=0.9,beta_2=0.999,epsilon=1e-08)
    loss=[preceptual_loss,wasserstein_loss]
    loss_weights=[100,1]
    d_on_g.compile(optimizer=d_on_g_opt,loss=loss,loss_weights=loss_weights)
    d.trainable=True

    # Wasserstein labels: +1 for real images, -1 for generated ones.
    output_true_batch,output_false_batch=np.ones((batch_size,1)),-np.ones((batch_size,1))
    for epoch in tqdm.tqdm(range(epoch_num)):
        permutated_indexes=np.random.permutation(x_train.shape[0])  # reshuffle every epoch
        d_losses=[]
        d_on_g_losses=[]
        for index in range(n_batches):
            batch_indexes=permutated_indexes[index*batch_size:(index+1)*batch_size]
            image_blur_batch=x_train[batch_indexes]
            image_full_batch=y_train[batch_indexes]
            generated_images=g.predict(x=image_blur_batch,batch_size=batch_size)
            # Several critic steps per generator step, as in WGAN (n_critic):
            # the critic is pushed closer to its optimum before the generator
            # follows its gradient.
            for _ in range(critic_updates):
                d_loss_real=d.train_on_batch(image_full_batch,output_true_batch)
                d_loss_fake=d.train_on_batch(generated_images,output_false_batch)
                d_loss=0.5*np.add(d_loss_fake,d_loss_real)
                d_losses.append(d_loss)
            # Generator step: restored image vs. sharp target, and the critic
            # score pushed toward the "real" label.
            d.trainable=False
            d_on_g_loss=d_on_g.train_on_batch(image_blur_batch,[image_full_batch,output_true_batch])
            d_on_g_losses.append(d_on_g_loss)
            d.trainable=True
        # NOTE(review): each d_on_g loss is presumably [total, perceptual,
        # wasserstein]; np.mean averages all of these entries together.
        mean_d_loss=np.mean(d_losses)
        mean_d_on_g_loss=np.mean(d_on_g_losses)
        print(mean_d_loss,mean_d_on_g_loss)
        with open('log.txt','a') as f:
            f.write('{}-{}-{}\n'.format(epoch,mean_d_loss,mean_d_on_g_loss))
        save_all_weights(d,g,epoch,int(mean_d_on_g_loss))
# Command-line entry point.  The click decorators are disabled.
# NOTE(review): train_command has no log_dir parameter, so the --log_dir option
# would have to be dropped (or a parameter added) before re-enabling them.
#@click.command()
#@click.option('--n_images', default=10, help='Number of images to load for training')
#@click.option('--batch_size', default=1, help='Size of batch')
#@click.option('--log_dir', required=True, help='Path to the log_dir for Tensorboard')

#@click.option('--epoch_num', default=1, help='Number of epochs for training')
#@click.option('--critic_updates', default=1, help='Number of discriminator training')
def train_command(n_images, batch_size,  epoch_num, critic_updates):
    """Thin wrapper that forwards its settings to train_multiple_outputs."""
    return train_multiple_outputs(
        n_images=n_images,
        batch_size=batch_size,
        epoch_num=epoch_num,
        critic_updates=critic_updates,
    )
if __name__ == '__main__':
    # Small run: 10 image pairs, batch size 1, 1 epoch, 1 critic update per batch.
    train_command(n_images=10, batch_size=1, epoch_num=1, critic_updates=1)
        

6.参考

[1] https://github.com/RaphaelMeudec/deblur-gan

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我是小z呀

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值