# SRGAN基于keras实现代码框架 (SRGAN code skeleton implemented with Keras)

class VGGNetwork:
    """Placeholder for the VGG network part of the SRGAN model.

    Only the interface is sketched; the network itself is not implemented.
    """

    def append_vgg_network(self, x_in, true_X_input):
        """Attach the VGG network to ``x_in`` and return the VGG output.

        Args:
            x_in: Input tensor the VGG network is appended to.
            true_X_input: Ground-truth input; its exact role is not
                specified by this skeleton (TODO confirm).

        Raises:
            NotImplementedError: the skeleton has no implementation yet.
        """
        # Raise explicitly instead of returning an undefined name, which
        # would surface as a confusing NameError.
        raise NotImplementedError("VGGNetwork.append_vgg_network is not implemented")

    def load_vgg_weight(self, model):
        """Intended to load VGG weights into ``model``; returns ``model``.

        Currently a no-op: ``model`` is returned unchanged.
        """
        return model
class DiscriminatorNetwork:
    """Placeholder for the SRGAN discriminator network.

    Only the interface is sketched; the network itself is not implemented.
    """

    def append_gan_network(self, true_X_input):
        """Build the discriminator on top of ``true_X_input`` and return its output.

        Args:
            true_X_input: Input tensor fed to the discriminator.

        Raises:
            NotImplementedError: the skeleton has no implementation yet.
        """
        # Raise explicitly instead of returning an undefined name, which
        # would surface as a confusing NameError.
        raise NotImplementedError("DiscriminatorNetwork.append_gan_network is not implemented")
class GenerativeNetwork:
    """Placeholder for the SRGAN generator (super-resolution) network."""

    def create_sr_model(self, ip):
        """Build the generator on top of input ``ip`` and return its output.

        Raises:
            NotImplementedError: the skeleton has no implementation yet.
        """
        # Raise explicitly instead of returning an undefined name, which
        # would surface as a confusing NameError.
        raise NotImplementedError("GenerativeNetwork.create_sr_model is not implemented")

    def get_generator_output(self, input_img, srgan_model):
        """Run the generator on ``input_img`` via ``self.output_func``.

        Args:
            input_img: Input passed to ``self.output_func`` wrapped in a list.
            srgan_model: Unused here; kept for interface compatibility.

        Returns:
            Whatever ``self.output_func([input_img])`` returns.

        Raises:
            NotImplementedError: if ``self.output_func`` has not been set.
        """
        # ``output_func`` is never assigned in this skeleton. Presumably it is
        # a callable built from the generator model (TODO confirm). Fail with
        # a clear message instead of an AttributeError when it is missing.
        output_func = getattr(self, "output_func", None)
        if output_func is None:
            raise NotImplementedError(
                "GenerativeNetwork.output_func must be set before calling get_generator_output"
            )
        return output_func([input_img])
class SRGANNetwork:
    """Skeleton manager for building and training the SRGAN models.

    This is an outline only: every build/training method raises
    ``NotImplementedError``. The docstring of each training method records
    the steps the outline prescribes, so they can be filled in later.
    """

    def __init__(self, img_width=96, img_height=96, batch_size=16):
        """Store the image size and batch size used for training.

        Args:
            img_width: Training image width.
            img_height: Training image height.
            batch_size: Number of images per training batch.
        """
        # The ``__main__`` block constructs this class with these keyword
        # arguments. The defaults are placeholders, so SRGANNetwork() still works.
        self.img_width = img_width
        self.img_height = img_height
        self.batch_size = batch_size
        # Intended to be populated by the build_* methods once implemented.
        self.srgan_model_ = None
        self.discriminative_model_ = None

    def build_srgan_pretrain_model(self):
        """Build the (presumably generator + VGG) model used for pre-training.

        Intended to assign and return ``self.srgan_model_``.

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError("SRGANNetwork.build_srgan_pretrain_model is not implemented")

    def build_discriminator_pretrain_model(self):
        """Build the discriminator model used for discriminator pre-training.

        Intended to assign and return ``self.discriminative_model_``.

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError(
            "SRGANNetwork.build_discriminator_pretrain_model is not implemented"
        )

    def build_srgan_model(self):
        """Build the full SRGAN model.

        Intended to assign and return ``self.srgan_model_``.

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError("SRGANNetwork.build_srgan_model is not implemented")

    def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
        """Pre-train the generator together with the VGG network.

        Planned loop: for each of ``nb_epochs`` epochs, and for each batch
        ``x`` drawn from ``datagen.flow_from_directory`` over ``image_dir``:

        1. Every 50 iterations (skipping iteration 0), run validation and
           print the PSNR.
        2. Train only the generator + VGG network.
        3. Every 1000 iterations (skipping iteration 0), save the model
           weights.

        Args:
            image_dir: Directory containing the training images.
            nb_images: Presumably the number of images to train on.
            nb_epochs: Number of epochs.
            use_small_srgan: Option from the outline; its meaning is not
                specified here (TODO confirm).

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError("SRGANNetwork.pre_train_srgan is not implemented")

    def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
        """Pre-train the discriminator on its own.

        Planned loop: for each of ``nb_epochs`` epochs, and for each batch
        ``x`` drawn from ``datagen.flow_from_directory`` over ``image_dir``:

        1. Train only the discriminator.
        2. Every 1000 iterations (skipping iteration 0), save the model
           weights.

        Args:
            image_dir: Directory containing the training images.
            nb_images: Presumably the number of images to train on.
            nb_epochs: Number of epochs.
            batch_size: Batch size for this pre-training stage.

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError("SRGANNetwork.pre_train_discriminator is not implemented")

    def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):
        """Adversarially train the generator and discriminator.

        Planned loop: for each of ``nb_epochs`` epochs, and for each batch
        ``x`` drawn from ``datagen.flow_from_directory`` over ``image_dir``:

        1. Every 50 iterations (skipping iteration 0), run validation and
           print the PSNR.
        2. Every 1000 iterations (skipping iteration 0), save the model
           weights.
        3. Train only the discriminator, with SRGAN (generator) training
           disabled.
        4. Train only the generator, with discriminator training disabled.

        Args:
            image_dir: Directory containing the training images.
            nb_images: Presumably the number of images to train on.
            nb_epochs: Number of epochs.

        Raises:
            NotImplementedError: not implemented in this skeleton.
        """
        raise NotImplementedError("SRGANNetwork.train_full_model is not implemented")
if __name__ == "__main__":
    import sys

    # Path to the MS COCO 2014 training images. Pass a directory as the first
    # command-line argument to override the author's hard-coded default.
    default_coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"
    coco_path = sys.argv[1] if len(sys.argv) > 1 else default_coco_path

    # Base network manager for the SRGAN model.
    #
    # Width / Height = 32 to reduce the memory requirement for the
    # discriminator.
    #
    # Batch size = 1 is slower, but uses the least amount of GPU memory, and
    # also acts as Instance Normalization (batch norm with 1 input image),
    # which speeds up training slightly.
    srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
    srgan_network.build_srgan_model()

    # Optional: plot the model graph. The old Keras 1 import
    # (keras.utils.visualize_util.plot) no longer exists in Keras 2+.
    # from keras.utils import plot_model
    # plot_model(srgan_network.srgan_model_, 'SRGAN.png', show_shapes=True)

    # Pretrain the SRGAN network
    # srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)

    # Pretrain the discriminator network
    # srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)

    # Fully train the SRGAN with VGG loss and Discriminator loss
    srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值