class VGGNetwork:
    """Helper for the VGG network part of the SRGAN model.

    Skeleton only: the code that builds the VGG network is not included here.
    """

    def append_vgg_network(self, x_in, true_X_input):
        """Build the VGG network over the given inputs and return its output.

        Per the original skeleton's comment, the return value is "the output
        of VGG". How the two inputs are combined is not shown here.

        Args:
            x_in: presumably the generator's output tensor — TODO confirm.
            true_X_input: presumably the ground-truth image input — TODO confirm.

        Raises:
            NotImplementedError: always. The network-building code is elided, so
                this fails explicitly instead of returning an undefined name.
        """
        raise NotImplementedError(
            "append_vgg_network: VGG construction is not part of this skeleton"
        )

    def load_vgg_weight(self, model):
        """Return ``model`` unchanged.

        Presumably the full implementation loads pretrained VGG weights into
        ``model`` before returning it — TODO confirm. This skeleton performs
        no weight loading.
        """
        return model
class DiscriminatorNetwork:
    """Helper for the GAN discriminator part of the SRGAN model.

    Skeleton only: the code that builds the discriminator is not included here.
    """

    def append_gan_network(self, true_X_input):
        """Build the discriminator network on top of ``true_X_input``.

        Args:
            true_X_input: presumably the ground-truth image input — TODO confirm.

        Raises:
            NotImplementedError: always. The network-building code is elided, so
                this fails explicitly instead of returning an undefined name.
        """
        raise NotImplementedError(
            "append_gan_network: discriminator construction is not part of this skeleton"
        )
class GenerativeNetwork:
    """Helper for the super-resolution generator part of the SRGAN model.

    Skeleton only: the code that builds the generator is not included here.
    """

    def create_sr_model(self, ip):
        """Build the super-resolution generator on top of input ``ip``.

        Args:
            ip: presumably a Keras input tensor — TODO confirm.

        Raises:
            NotImplementedError: always. The model-building code is elided, so
                this fails explicitly instead of returning an undefined name.
        """
        raise NotImplementedError(
            "create_sr_model: generator construction is not part of this skeleton"
        )

    def get_generator_output(self, input_img, srgan_model):
        """Call ``self.output_func`` on ``input_img`` and return its result.

        ``input_img`` is wrapped in a one-element list before the call, so
        ``output_func`` must accept a list of inputs (this matches the calling
        convention of a Keras backend function — TODO confirm).

        ``srgan_model`` is not used in this body. Presumably the elided code
        used it to build ``output_func`` — verify against the full source.

        Raises:
            NotImplementedError: if ``output_func`` has not been set on the
                instance. Building it is not part of this skeleton.
        """
        output_func = getattr(self, "output_func", None)
        if output_func is None:
            raise NotImplementedError(
                "get_generator_output: output_func has not been built; "
                "constructing it from srgan_model is not part of this skeleton"
            )
        return output_func([input_img])
class SRGANNetwork:
    """Manager that builds and trains the SRGAN models.

    Skeleton only: the model-construction and training code is elided.
    - The ``build_*`` methods return whichever model is cached on the instance;
      that is ``None`` until something assigns it.
    - The training methods document the intended loop and then raise
      ``NotImplementedError``.
    """

    def __init__(self, img_width=32, img_height=32, batch_size=1):
        """Store the image geometry and batch size.

        The defaults mirror the values the ``__main__`` script passes.

        Args:
            img_width: Image width. Whether this is the low- or high-resolution
                size is not shown in this skeleton — TODO confirm.
            img_height: Image height (same caveat as ``img_width``).
            batch_size: Number of images per training batch.
        """
        self.img_width = img_width
        self.img_height = img_height
        self.batch_size = batch_size

        # Cached models returned by the build_* methods. They stay None here
        # because the construction code is not part of this skeleton.
        self.srgan_model_ = None
        self.discriminative_model_ = None

    def build_srgan_pretrain_model(self):
        """Return the cached SRGAN model used for generator pre-training.

        The construction code is elided. This returns ``self.srgan_model_``,
        which is ``None`` unless assigned elsewhere.
        """
        return self.srgan_model_

    def build_discriminator_pretrain_model(self):
        """Return the cached discriminator model used for pre-training.

        The construction code is elided. This returns
        ``self.discriminative_model_``, which is ``None`` unless assigned elsewhere.
        """
        return self.discriminative_model_

    def build_srgan_model(self):
        """Return the cached full SRGAN model.

        The construction code is elided. This returns ``self.srgan_model_``,
        which is ``None`` unless assigned elsewhere.
        """
        return self.srgan_model_

    def pre_train_srgan(self, image_dir, nb_images=50000, nb_epochs=1, use_small_srgan=False):
        """Pre-train the generator together with the VGG network.

        Intended outline (from the original skeleton)::

            for each epoch in range(nb_epochs):
                for each batch from datagen.flow_from_directory(image_dir, ...):
                    every 50th iteration (skipping iteration 0): validate / print PSNR
                    train only the generator + VGG network
                    every 1000th iteration (skipping iteration 0): save model weights

        Args:
            image_dir: Directory of training images.
            nb_images: Number of images to train on per epoch (presumably —
                TODO confirm).
            nb_epochs: Number of passes over the data.
            use_small_srgan: Flag carried over from the original signature; its
                effect is not shown in this skeleton.

        Raises:
            NotImplementedError: always. The training loop is elided.
        """
        raise NotImplementedError("pre_train_srgan: training loop is not part of this skeleton")

    def pre_train_discriminator(self, image_dir, nb_images=50000, nb_epochs=1, batch_size=128):
        """Pre-train the discriminator on its own.

        Intended outline (from the original skeleton)::

            for each epoch in range(nb_epochs):
                for each batch from datagen.flow_from_directory(image_dir, ...):
                    train only the discriminator
                    every 1000th iteration (skipping iteration 0): save model weights

        Args:
            image_dir: Directory of training images.
            nb_images: Number of images to train on per epoch (presumably —
                TODO confirm).
            nb_epochs: Number of passes over the data.
            batch_size: Batch size for this phase. It is separate from
                ``self.batch_size``.

        Raises:
            NotImplementedError: always. The training loop is elided.
        """
        raise NotImplementedError(
            "pre_train_discriminator: training loop is not part of this skeleton"
        )

    def train_full_model(self, image_dir, nb_images=50000, nb_epochs=10):
        """Train the full SRGAN, alternating discriminator and generator updates.

        Intended outline (from the original skeleton)::

            for each epoch in range(nb_epochs):
                for each batch from datagen.flow_from_directory(image_dir, ...):
                    every 50th iteration (skipping iteration 0): validate / print PSNR
                    every 1000th iteration (skipping iteration 0): save model weights
                    train only the discriminator (SRGAN training disabled)
                    train only the generator (discriminator training disabled)

        Args:
            image_dir: Directory of training images.
            nb_images: Number of images to train on per epoch (presumably —
                TODO confirm).
            nb_epochs: Number of passes over the data.

        Raises:
            NotImplementedError: always. The training loop is elided.
        """
        raise NotImplementedError("train_full_model: training loop is not part of this skeleton")
if __name__ == "__main__":
    import sys

    # Path to the MS COCO 2014 training images. The default is the original
    # author's local path. Pass a different directory as the first
    # command-line argument to override it.
    default_coco_path = r"D:\Yue\Documents\Dataset\coco2014\train2014"
    coco_path = sys.argv[1] if len(sys.argv) > 1 else default_coco_path

    # Base network manager for the SRGAN model.
    # - Width / Height = 32 reduces the memory requirement of the discriminator.
    # - Batch size = 1 is slower but uses the least GPU memory. With a single
    #   image per batch, batch normalization acts much like instance
    #   normalization; the original author notes this speeds up training slightly.
    srgan_network = SRGANNetwork(img_width=32, img_height=32, batch_size=1)
    srgan_network.build_srgan_model()

    # To visualize the model, uncomment the lines below (requires pydot and
    # graphviz). The old `keras.utils.visualize_util.plot` import no longer
    # exists in Keras 2+; `keras.utils.plot_model` replaces it.
    # from keras.utils import plot_model
    # plot_model(srgan_network.srgan_model_, 'SRGAN.png', show_shapes=True)

    # Pretrain the SRGAN network
    # srgan_network.pre_train_srgan(coco_path, nb_images=80000, nb_epochs=1)

    # Pretrain the discriminator network
    # srgan_network.pre_train_discriminator(coco_path, nb_images=40000, nb_epochs=1, batch_size=16)

    # Fully train the SRGAN with VGG loss and Discriminator loss
    srgan_network.train_full_model(coco_path, nb_images=80000, nb_epochs=5)
# SRGAN: code framework implemented with Keras (original title: "SRGAN基于keras实现代码框架")
# Web-page residue: "Latest recommended article published 2024-09-20 20:09:43" (最新推荐文章于 2024-09-20 20:09:43 发布)