这周我的工作主要是封装风格迁移网络对外的接口。经过一周的训练,我们已经得到了七种风格的网络模型,分别保存在各自的
ckpt 文件中。接口首先判断用户选择的风格,然后通过 TensorFlow 加载相应的网络模型,对用户传入的图片进行风格迁移,保存结果并返回给用户。
主要代码如下:
# Default directory holding the trained style-transfer checkpoints. It can be
# overridden at call time with the STYLE_MODEL_DIR environment variable.
_DEFAULT_MODEL_DIR = "E:\\programming\\PYworkbench\\Style-Transformer-Website\\trained_models\\"

# Maps the style id received on the command line to its checkpoint file name.
_STYLE_MODEL_FILES = {
    '1': "shuimo.ckpt-done",
    '2': "cubist.ckpt-6000",
    '3': "denoised_starry.ckpt-done",
    '4': "feathers.ckpt-done",
    '5': "mosaic.ckpt-done",
    '6': "scream.ckpt-done",
    '7': "udnie.ckpt-done",
    '8': "wave.ckpt-done",
    '9': "jianzhi.ckpt-4000",
}

# Loss network whose (eval-mode) preprocessing is applied to the input image.
_LOSS_MODEL = 'vgg_16'

# Style id that receives the extra sky-segmentation post-processing pass, and
# the script (path relative to the working directory) that performs it.
_SKY_POSTPROCESS_STYLE = '1'
_SKY_POSTPROCESS_SCRIPT = 'Sky_segment_postProcessing/sky_postprocessing.py'


def _model_file_for_style(style, model_dir):
    """Return the checkpoint path for ``style`` inside ``model_dir``.

    Raises:
        ValueError: if ``style`` is not one of the known style ids.
    """
    try:
        file_name = _STYLE_MODEL_FILES[style]
    except KeyError:
        raise ValueError('Unknown style %r; expected one of: %s'
                         % (style, ', '.join(sorted(_STYLE_MODEL_FILES)))) from None
    return os.path.join(model_dir, file_name)


def _read_image_size(image_file):
    """Decode ``image_file`` once and return its ``(height, width)``."""
    with open(image_file, 'rb') as f:
        data = f.read()
    # Use a throwaway graph so the decode ops do not accumulate in the global
    # default graph, and 'with Session()' so the session is actually closed.
    with tf.Graph().as_default(), tf.Session() as sess:
        if image_file.lower().endswith('png'):
            image = sess.run(tf.image.decode_png(data))
        else:
            image = sess.run(tf.image.decode_jpeg(data))
    return image.shape[0], image.shape[1]


def _sky_postprocess(raw_img, generated_file):
    """Run the sky-segmentation post-processing script on the stylized image.

    Best-effort: a failing script is reported but does not raise.
    """
    import subprocess
    import sys

    # Argument list (no shell): user-supplied paths containing spaces or shell
    # metacharacters are passed verbatim instead of being interpreted.
    # sys.executable runs the script with the same interpreter as this one.
    ret = subprocess.call(
        [sys.executable, _SKY_POSTPROCESS_SCRIPT, raw_img, generated_file])
    if ret != 0:
        tf.logging.warning('Sky post-processing exited with status %d' % ret)


def main(argv):
    """Stylize one image with a pretrained fast style-transfer network.

    Args:
        argv: ``[program, style, raw_image_path, generated_image_path]``.
            ``style`` is a string id ('1'..'9') selecting the checkpoint; the
            stylized result is JPEG-encoded and written to
            ``generated_image_path``. For style '1' an extra post-processing
            script is run on the result.

    Raises:
        ValueError: if ``argv`` has fewer than four entries or ``style`` is
            not a known style id.
    """
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    if len(argv) < 4:
        raise ValueError('usage: %s <style> <raw_image> <generated_image>'
                         % (argv[0] if argv else 'script'))
    style, raw_img, generated_file = argv[1], argv[2], argv[3]
    print(style, raw_img, generated_file)

    model_dir = os.environ.get('STYLE_MODEL_DIR', _DEFAULT_MODEL_DIR)
    # Absolute path so restoring does not depend on the working directory.
    model_file = os.path.abspath(_model_file_for_style(style, model_dir))

    height, width = _read_image_size(raw_img)
    tf.logging.info('Image size: %dx%d' % (width, height))

    with tf.Graph().as_default():
        # 'with Session()' (not Session().as_default()) closes it on exit.
        with tf.Session() as sess:
            # Read and preprocess the image data.
            image_preprocessing_fn, _ = preprocessing_factory.get_preprocessing(
                _LOSS_MODEL,
                is_training=False)
            image = reader.get_image(raw_img, height, width, image_preprocessing_fn)
            # Add batch dimension.
            image = tf.expand_dims(image, 0)
            generated = model.net(image, training=False)
            generated = tf.cast(generated, tf.uint8)
            # Remove batch dimension.
            generated = tf.squeeze(generated, [0])

            # Restore model variables.
            saver = tf.train.Saver(tf.global_variables(),
                                   write_version=tf.train.SaverDef.V1)
            sess.run([tf.global_variables_initializer(),
                      tf.local_variables_initializer()])
            saver.restore(sess, model_file)

            # Make sure the directory of the requested output file exists.
            os.makedirs(os.path.dirname(os.path.abspath(generated_file)),
                        exist_ok=True)

            start_time = time.time()
            with open(generated_file, 'wb') as out:
                out.write(sess.run(tf.image.encode_jpeg(generated)))
            # The output file is closed (fully flushed) before the
            # post-processing script reads it.
            if style == _SKY_POSTPROCESS_STYLE:
                _sky_postprocess(raw_img, generated_file)
            end_time = time.time()
            print('Elapsed time: %fs' % (end_time - start_time))
            print('Done. Please check %s.' % generated_file)