model.py
可能的解法 (possible solution):
https://github.com/openai/InfoGAN/blob/master/infogan/misc/custom_ops.py
'''
Heavily based on:
resnet: https://github.com/bgshih/tf_resnet_cifar/blob/master/src/model_resnet.py
Chainer implementation: https://github.com/yusuketomoto/chainer-fast-neuralstyle
TensorFlow implementation: https://github.com/OlavHN/fast-neural-style
'''
from __future__ import division
import math
#import ipdb
import tensorflow as tf
#from tensorflow.python import control_flow_ops
import numpy as np
#import model_utils as mu
def conv2d(x, n_in, n_out, k, s, p='SAME', bias=False, scope='conv'):
    """
    Square 2-D convolution with an optional bias term.
    Args:
        x:     input tensor handed to tf.nn.conv2d
        n_in:  integer, number of input channels
        n_out: integer, number of output channels
        k:     integer, side length of the k x k kernel
        s:     integer, stride applied along both spatial axes
        p:     string, padding mode passed to tf.nn.conv2d
        bias:  when true, add a zero-initialised bias vector of size n_out
        scope: string, variable scope
    Return:
        the convolution output, with the bias added when requested
    """
    with tf.variable_scope(scope):
        # He-style initialisation: stddev = sqrt(2 / fan_in).
        fan_in = k * k * n_in
        initial = tf.truncated_normal([k, k, n_in, n_out],
                                      stddev=math.sqrt(2 / fan_in))
        weight = tf.Variable(initial, name='weight')
        tf.add_to_collection('weights', weight)

        output = tf.nn.conv2d(x, weight, [1, s, s, 1], padding=p)
        if not bias:
            return output

        offset = tf.get_variable('bias', [n_out],
                                 initializer=tf.constant_initializer(0.0))
        tf.add_to_collection('biases', offset)
        return tf.nn.bias_add(output, offset)
def batch_norm(x, n_out, phase_train, scope='bn', affine=True):
    """
    Batch-normalize convolutional feature maps.
    Args:
        x:           Tensor; its moments are taken over axes 0, 1 and 2
        n_out:       integer, length of the beta and gamma vectors
        phase_train: boolean tensor; true uses the statistics of the current
                     batch (and updates their moving averages), false uses
                     the moving averages
        scope:       string, variable scope
        affine:      whether gamma is trainable and applied as a scale
    Return:
        normed:      batch-normalized maps
    """
    with tf.variable_scope(scope):
        offset = tf.Variable(tf.constant(0.0, shape=[n_out]),
                             name='beta', trainable=True)
        scale = tf.Variable(tf.constant(1.0, shape=[n_out]),
                            name='gamma', trainable=affine)
        tf.add_to_collection('biases', offset)
        tf.add_to_collection('weights', scale)

        batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')
        averager = tf.train.ExponentialMovingAverage(decay=0.99)

        def _batch_statistics():
            # The identity ops sit behind the moving-average update, so
            # reading the batch statistics also refreshes the averages.
            update_op = averager.apply([batch_mean, batch_var])
            with tf.control_dependencies([update_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        def _moving_statistics():
            return averager.average(batch_mean), averager.average(batch_var)

        mean, var = tf.cond(phase_train, _batch_statistics, _moving_statistics)
        return tf.nn.batch_norm_with_global_normalization(
            x, mean, var, offset, scale, 1e-3, affine)
def residual_block(x, n_in, n_out, phase_train, scope='res_block'):
    """
    Two 3x3 convolution + batch-norm stages with an identity shortcut.
    Args:
        x:           input tensor; it is added unchanged to the branch output
        n_in:        integer, channels of x
        n_out:       integer, channels produced by both convolutions
        phase_train: boolean tensor forwarded to batch_norm
        scope:       string, variable scope
    Return:
        the sum of the convolutional branch and the shortcut
    """
    with tf.variable_scope(scope):
        branch = conv2d(x, n_in, n_out, 3, 1, 'SAME', False, scope='conv_1')
        skip = tf.identity(x, name='shortcut')
        branch = tf.nn.relu(
            batch_norm(branch, n_out, phase_train, scope='bn_1'),
            name='relu_1')
        branch = conv2d(branch, n_out, n_out, 3, 1, 'SAME', True, scope='conv_2')
        branch = batch_norm(branch, n_out, phase_train, scope='bn_2')
        # A trailing ReLU (name 'relu_2') is left disabled here; the original
        # note on that line reads "for best result".
        return branch + skip
def conv2d_transpose(x, input_filters, output_filters, kernel, strides, padding='SAME'):
    """
    Transposed convolution whose output is `strides` times larger spatially.
    Args:
        x:              input tensor; axes 1 and 2 are scaled by `strides`
        input_filters:  integer, channels of x
        output_filters: integer, channels of the result
        kernel:         integer, side length of the square kernel
        strides:        integer, stride along both spatial axes
        padding:        string, padding mode passed to tf.nn.conv2d_transpose
    Return:
        the transposed-convolution output
    """
    with tf.variable_scope('conv_transpose'):
        # tf.nn.conv2d_transpose expects filters laid out as
        # [height, width, output channels, input channels].
        filter_shape = [kernel, kernel, output_filters, input_filters]
        weight = tf.Variable(tf.truncated_normal(filter_shape, stddev=0.1),
                             name='weight')

        # The output shape is computed dynamically so that batch size and
        # spatial extent need not be known when the graph is built.
        batch_size = tf.shape(x)[0]
        out_height = tf.shape(x)[1] * strides
        out_width = tf.shape(x)[2] * strides
        target_shape = tf.pack([batch_size, out_height, out_width, output_filters])

        return tf.nn.conv2d_transpose(
            x, weight, target_shape,
            strides=[1, strides, strides, 1],
            padding=padding,
            name='conv_transpose')
def net(image, phase_train):
    """
    Image transformation network: 3 conv layers, 5 residual blocks, 3 deconvs.
    Args:
        image:       input tensor with 3 channels
        phase_train: boolean tensor forwarded to every batch_norm
    Return:
        3-channel output scaled to the range [0, 255] via (tanh + 1) * 127.5
    """
    # (scope, input channels, output channels, kernel size, stride)
    down_layers = (
        ('convLay1', 3, 32, 9, 1),
        ('convLay2', 32, 64, 4, 2),
        ('convLay3', 64, 128, 4, 2),
    )
    # (scope, input channels, output channels); kernel 4, stride 2
    up_layers = (
        ('deconvLay1', 128, 64),
        ('deconvLay2', 64, 32),
    )

    y = image
    for layer_scope, n_in, n_out, k, s in down_layers:
        with tf.variable_scope(layer_scope):
            y = conv2d(y, n_in, n_out, k, s, 'SAME', False, scope='conv_init')
            y = batch_norm(y, n_out, phase_train, scope='bn_init')
            y = tf.nn.relu(y, name='relu_init')

    for index in range(1, 6):
        with tf.variable_scope('residualLay%d' % index):
            y = residual_block(y, 128, 128, phase_train)

    for layer_scope, n_in, n_out in up_layers:
        with tf.variable_scope(layer_scope):
            y = conv2d_transpose(y, n_in, n_out, 4, 2)
            y = batch_norm(y, n_out, phase_train, scope='bn_init')
            y = tf.nn.relu(y, name='relu_init')

    with tf.variable_scope('deconvLay3'):
        y = conv2d_transpose(y, 32, 3, 9, 1)

    # tanh yields [-1, 1]; shift and scale to [0, 255].
    return (tf.nn.tanh(y) + 1) * 127.5
fast_neural_style.py
from scipy import misc
import os
import time
import tensorflow as tf
import vgg
import model
import reader
# Command-line flags.  The three loss weights are floats: their defaults
# (5e0, 1e2, 1e-5) are float literals, and a value such as 1e-5 cannot be
# parsed by an integer flag.
tf.app.flags.DEFINE_float("CONTENT_WEIGHT", 5e0, "Weight for content features loss")
tf.app.flags.DEFINE_float("STYLE_WEIGHT", 1e2, "Weight for style features loss")
tf.app.flags.DEFINE_float("TV_WEIGHT", 1e-5, "Weight for total variation loss")
tf.app.flags.DEFINE_string("VGG_PATH", "imagenet-vgg-verydeep-19.mat",
                           "Path to vgg model weights")
tf.app.flags.DEFINE_string("MODEL_PATH", "models", "Path to read/write trained models")
tf.app.flags.DEFINE_string("TRAIN_IMAGES_PATH", "train2014", "Path to training images")
tf.app.flags.DEFINE_string("CONTENT_LAYERS", "relu4_2",
                           "Which VGG layer to extract content loss from")
tf.app.flags.DEFINE_string("STYLE_LAYERS", "relu1_1,relu2_1,relu3_1,relu4_1,relu5_1",
                           "Which layers to extract style from")
tf.app.flags.DEFINE_string("SUMMARY_PATH", "tensorboard", "Path to store Tensorboard summaries")
tf.app.flags.DEFINE_string("STYLE_IMAGES", "style.png", "Styles to train")
tf.app.flags.DEFINE_float("STYLE_SCALE", 1.0, "Scale styles. Higher extracts smaller features")
# When CONTENT_IMAGES_PATH is set, main() stylizes those images instead of training.
tf.app.flags.DEFINE_string("CONTENT_IMAGES_PATH", None, "Path to content image(s)")
tf.app.flags.DEFINE_integer("IMAGE_SIZE", 256, "Size of output image")
tf.app.flags.DEFINE_integer("BATCH_SIZE", 1, "Number of concurrent images to train on")
FLAGS = tf.app.flags.FLAGS
def total_variation_loss(layer):
    """
    Total variation penalty on neighbouring values along axes 1 and 2.

    For each of the two axes, the tensor is differenced against itself shifted
    by one position, and the L2 loss of the difference is divided by its
    element count; the two terms are summed.
    Args:
        layer: 4-D tensor (the slices index four axes)
    Return:
        scalar tensor with the summed, size-normalised penalty
    """
    dims = tf.shape(layer)
    rows = dims[1]
    cols = dims[2]
    everything = [-1, -1, -1, -1]

    # Differences between each row and the row below it.
    all_but_last_row = tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, rows - 1, -1, -1]))
    all_but_first_row = tf.slice(layer, [0, 1, 0, 0], everything)
    vertical = all_but_last_row - all_but_first_row

    # Differences between each column and the column to its right.
    all_but_last_col = tf.slice(layer, [0, 0, 0, 0], tf.pack([-1, -1, cols - 1, -1]))
    all_but_first_col = tf.slice(layer, [0, 0, 1, 0], everything)
    horizontal = all_but_last_col - all_but_first_col

    horizontal_term = tf.nn.l2_loss(horizontal) / tf.to_float(tf.size(horizontal))
    vertical_term = tf.nn.l2_loss(vertical) / tf.to_float(tf.size(vertical))
    return horizontal_term + vertical_term
# Gram matrices are normalised by the size of ONE image (height * width *
# channels) rather than by tf.size(layer) / FLAGS.BATCH_SIZE.  The two agree
# for a training batch, but get_style_features() stacks one entry per style
# image, a count unrelated to BATCH_SIZE, which previously mis-scaled the
# style targets.
def gram(layer):
    """
    Per-image Gram matrices of a batch of feature maps.
    Args:
        layer: 4-D tensor; axis 0 indexes images and axis 3 indexes filters
    Return:
        tensor of shape [num_images, num_filters, num_filters], each matrix
        divided by the number of elements in a single image's feature map
    """
    shape = tf.shape(layer)
    num_images = shape[0]
    num_filters = shape[3]
    per_image_size = tf.to_float(shape[1] * shape[2] * num_filters)

    # Flatten the two spatial axes: [num_images, positions, num_filters].
    filters = tf.reshape(layer, tf.pack([num_images, -1, num_filters]))
    # adj_x=True transposes the first operand, giving filters^T . filters.
    grams = tf.batch_matmul(filters, filters, adj_x=True) / per_image_size
    return grams
def get_style_features(style_paths, style_layers):
    """
    Evaluate the Gram matrices of the style images at the given VGG layers.

    The computation is built and run in its own temporary graph, so none of
    its ops are added to the default graph.
    Args:
        style_paths:  iterable of paths to the style images
        style_layers: iterable of VGG layer names
    Return:
        list with one evaluated Gram result per entry of style_layers
    """
    with tf.Graph().as_default():
        side = int(round(FLAGS.IMAGE_SIZE * FLAGS.STYLE_SCALE))
        style_batch = tf.pack([reader.get_image(path, side) for path in style_paths])
        activations, _ = vgg.net(FLAGS.VGG_PATH, style_batch)
        gram_ops = [gram(activations[name]) for name in style_layers]
        with tf.Session() as sess:
            return sess.run(gram_ops)
def _stylize_images(phase_train):
    """Run the trained network over FLAGS.CONTENT_IMAGES_PATH and save PNGs.

    Images are written to the working directory as out0001.png, out0002.png, ...
    Returns early (after printing a message) when no checkpoint exists in
    FLAGS.MODEL_PATH.
    """
    content_images = reader.image(
        FLAGS.BATCH_SIZE,
        FLAGS.IMAGE_SIZE,
        FLAGS.CONTENT_IMAGES_PATH,
        epochs=1,
        shuffle=False,
        crop=False)
    generated_images = model.net(content_images / 255., phase_train)
    # NOTE(review): model.net already scales its output to [0, 255]; confirm
    # that adding reader.mean_pixel on top is intended.
    output_format = tf.saturate_cast(generated_images + reader.mean_pixel, tf.uint8)

    with tf.Session() as sess:
        checkpoint = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if not checkpoint:
            print('Could not find trained model in {}'.format(FLAGS.MODEL_PATH))
            return
        print('Using model from {}'.format(checkpoint))
        saver = tf.train.Saver()
        saver.restore(sess, checkpoint)
        sess.run(tf.initialize_local_variables())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)

        i = 0
        start_time = time.time()
        try:
            while not coord.should_stop():
                print(i)
                # Inference: batch norm uses its moving averages.
                images_t = sess.run(output_format, {phase_train: False})
                elapsed = time.time() - start_time
                start_time = time.time()
                print('Time for one batch: {}'.format(elapsed))
                for raw_image in images_t:
                    i += 1
                    misc.imsave('out{0:04d}.png'.format(i), raw_image)
        except tf.errors.OutOfRangeError:
            # Raised by the input queue once its single epoch is exhausted.
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)


def _train_network(phase_train):
    """Train the transformation network on FLAGS.TRAIN_IMAGES_PATH.

    Preview images are written every 100 steps and a checkpoint is saved to
    FLAGS.MODEL_PATH every 10000 steps.
    """
    if not os.path.exists(FLAGS.MODEL_PATH):
        os.makedirs(FLAGS.MODEL_PATH)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    # Style targets are evaluated once, up front, as constant arrays.
    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images / 255., phase_train)

    # One VGG pass over [generated, images]; each layer's activations are
    # split back into the two halves below.
    net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat(0, [generated, images]))

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(0, 2, net[layer])
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
    content_loss = content_loss / len(content_layers)

    style_loss = 0
    for style_gram, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(0, 2, net[layer])
        size = tf.size(generated_images)
        for style_image in style_gram:
            style_loss += tf.nn.l2_loss(
                tf.reduce_sum(gram(generated_images) - style_image, 0)) / tf.to_float(size)
    style_loss = style_loss / len(style_layers)

    loss = (FLAGS.STYLE_WEIGHT * style_loss
            + FLAGS.CONTENT_WEIGHT * content_loss
            + FLAGS.TV_WEIGHT * total_variation_loss(generated))

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    output_format = tf.saturate_cast(
        tf.concat(0, [generated, images]) + reader.mean_pixel, tf.uint8)

    # batch_norm's tf.cond depends on this placeholder, so every run of the
    # network has to feed it.
    training_feed = {phase_train: True}

    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        checkpoint = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if checkpoint:
            print('Restoring model from {}'.format(checkpoint))
            saver.restore(sess, checkpoint)
        else:
            print('New model initilized')
            sess.run(tf.initialize_all_variables())
        # Local variables are not restored from the checkpoint, so they are
        # initialised on both paths.
        sess.run(tf.initialize_local_variables())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step], training_feed)
                elapsed_time = time.time() - start_time
                start_time = time.time()
                print(step, loss_t, elapsed_time)
                if step % 100 == 0:
                    output_t = sess.run(output_format, training_feed)
                    for i, raw_image in enumerate(output_t):
                        misc.imsave('out{}.png'.format(i), raw_image)
                    print('Save image.')
                if step % 10000 == 0:
                    saver.save(sess, os.path.join(FLAGS.MODEL_PATH, 'fast-style-model'),
                               global_step=step)
                    print('Save model.')
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()
        coord.join(threads)


def main(argv=None):
    """Stylize images when FLAGS.CONTENT_IMAGES_PATH is set, otherwise train.

    Args:
        argv: unused; accepted because tf.app.run() passes the argument list.
    """
    phase_train = tf.placeholder(tf.bool, name='phase_train')
    if FLAGS.CONTENT_IMAGES_PATH:
        _stylize_images(phase_train)
    else:
        _train_network(phase_train)
# Script entry point: hand control to TensorFlow's application runner.
if __name__ == "__main__":
    tf.app.run()
python3 fast_neural_style.py --TRAIN_IMAGES_PATH coco_img_path --STYLE_IMAGES style.png --BATCH_SIZE 4
python3 fast_neural_style.py --CONTENT_IMAGES_PATH path_to_images_to_transform