fast_style

 model.py

A possible fix:

https://github.com/openai/InfoGAN/blob/master/infogan/misc/custom_ops.py
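That link points at a hand-rolled batch-norm op. Below is a minimal sketch of the same idea, written against the old TF API used in this post; the helper name and the decay/epsilon values are illustrative, not taken from that repo. The point is to keep the moving mean/variance in explicit non-trainable variables so they are stored in the checkpoint and can simply be read back at inference time:

def batch_norm_sketch(x, n_out, phase_train, scope='bn', decay=0.99, eps=1e-3):
  # Hypothetical alternative to the batch_norm() defined in model.py below:
  # moving statistics live in plain variables, updated while training and
  # read directly at test time.
  with tf.variable_scope(scope):
    beta = tf.Variable(tf.constant(0.0, shape=[n_out]), name='beta')
    gamma = tf.Variable(tf.constant(1.0, shape=[n_out]), name='gamma')
    moving_mean = tf.Variable(tf.zeros([n_out]), name='moving_mean', trainable=False)
    moving_var = tf.Variable(tf.ones([n_out]), name='moving_var', trainable=False)

    batch_mean, batch_var = tf.nn.moments(x, [0, 1, 2], name='moments')

    def train_branch():
      # Update the running statistics, then use the batch statistics.
      update_mean = tf.assign(moving_mean, decay * moving_mean + (1 - decay) * batch_mean)
      update_var = tf.assign(moving_var, decay * moving_var + (1 - decay) * batch_var)
      with tf.control_dependencies([update_mean, update_var]):
        return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(phase_train, train_branch,
                        lambda: (tf.identity(moving_mean), tf.identity(moving_var)))
    return gamma * (x - mean) / tf.sqrt(var + eps) + beta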

'''
Heavily references:
  ResNet: https://github.com/bgshih/tf_resnet_cifar/blob/master/src/model_resnet.py
  Chainer implementation: https://github.com/yusuketomoto/chainer-fast-neuralstyle
  TensorFlow implementation: https://github.com/OlavHN/fast-neural-style
'''
from __future__ import division

import math
#import ipdb
import tensorflow as tf
#from tensorflow.python import control_flow_ops
import numpy as np

#import model_utils as mu

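# k x k convolution with stride s; weights use He initialization (stddev = sqrt(2 / (k*k*n_in))).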
def conv2d(x, n_in, n_out, k, s, p='SAME', bias=False, scope='conv'):
  with tf.variable_scope(scope):
    kernel = tf.Variable(
      tf.truncated_normal([k, k, n_in, n_out],
      stddev=math.sqrt(2/(k*k*n_in))),
      name='weight')
    tf.add_to_collection('weights', kernel)
    conv = tf.nn.conv2d(x, kernel, [1,s,s,1], padding=p)
    if bias:
      b = tf.get_variable('bias', [n_out], initializer=tf.constant_initializer(0.0))
      tf.add_to_collection('biases', b)
      conv = tf.nn.bias_add(conv, b)
  return conv

def batch_norm(x, n_out, phase_train, scope='bn', affine=True):
  """
  Batch normalization on convolutional maps.
  Args:
    x: Tensor, 4D BHWD input maps
    n_out: integer, depth of input maps
    phase_train: boolean tf.Variable, true indicates training phase
    scope: string, variable scope
    affine: whether to affine-transform outputs
  Return:
    normed: batch-normalized maps
  """
  with tf.variable_scope(scope):
    beta = tf.Variable(tf.constant(0.0, shape=[n_out]),
      name='beta', trainable=True)
    gamma = tf.Variable(tf.constant(1.0, shape=[n_out]),
      name='gamma', trainable=affine)
    tf.add_to_collection('biases', beta)
    tf.add_to_collection('weights', gamma)

    batch_mean, batch_var = tf.nn.moments(x, [0,1,2], name='moments')
    ema = tf.train.ExponentialMovingAverage(decay=0.99)

    def mean_var_with_update():
      ema_apply_op = ema.apply([batch_mean, batch_var])
      with tf.control_dependencies([ema_apply_op]):
        return tf.identity(batch_mean), tf.identity(batch_var)
    mean, var = tf.cond(phase_train,
      mean_var_with_update,
      lambda: (ema.average(batch_mean), ema.average(batch_var)))
    
    normed = tf.nn.batch_norm_with_global_normalization(x, mean, var, 
      beta, gamma, 1e-3, affine)
    
  return normed

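# Residual block: conv-BN-ReLU-conv-BN with an identity shortcut added to the output
# (the final ReLU is deliberately left out, see the commented line below).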
def residual_block(x, n_in, n_out, phase_train, scope='res_block'):
  with tf.variable_scope(scope):
    y = conv2d(x, n_in, n_out, 3, 1, 'SAME', False, scope='conv_1')
    shortcut = tf.identity(x, name='shortcut')
    y = batch_norm(y, n_out, phase_train, scope='bn_1')
    y = tf.nn.relu(y, name='relu_1')
    y = conv2d(y, n_out, n_out, 3, 1, 'SAME', True, scope='conv_2')
    y = batch_norm(y, n_out, phase_train, scope='bn_2')
    y = y + shortcut
    #y = tf.nn.relu(y, name='relu_2') #for best result
  return y

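# Transposed convolution ('deconvolution'): upsamples the spatial dimensions by a factor of `strides`.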
def conv2d_transpose(x, input_filters, output_filters, kernel, strides, padding='SAME'):
  with tf.variable_scope('conv_transpose') as scope:
    shape = [kernel, kernel, output_filters, input_filters]
    weight = tf.Variable(tf.truncated_normal(shape, stddev=0.1), name='weight')

    batch_size = tf.shape(x)[0]
    height = tf.shape(x)[1] * strides
    width = tf.shape(x)[2] * strides
    output_shape = tf.pack([batch_size, height, width, output_filters])

    convolved = tf.nn.conv2d_transpose(x, weight, output_shape,
      strides=[1, strides, strides, 1], padding=padding, name='conv_transpose')

    return convolved
  
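# Image transformation network: three downsampling conv layers, five residual blocks,
# and three upsampling (transpose conv) layers; the tanh output is rescaled to [0, 255].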
def net(image, phase_train):
  with tf.variable_scope('convLay1'):
    y = conv2d(image, 3, 32, 9, 1, 'SAME', False, scope='conv_init')
    y = batch_norm(y, 32, phase_train, scope='bn_init')
    y = tf.nn.relu(y, name='relu_init')
  with tf.variable_scope('convLay2'):
    y = conv2d(y, 32, 64, 4, 2, 'SAME', False, scope='conv_init')
    y = batch_norm(y, 64, phase_train, scope='bn_init')
    y = tf.nn.relu(y, name='relu_init')
  with tf.variable_scope('convLay3'):
    y = conv2d(y, 64, 128, 4, 2, 'SAME', False, scope='conv_init')
    y = batch_norm(y, 128, phase_train, scope='bn_init')
    y = tf.nn.relu(y, name='relu_init')
  with tf.variable_scope('residualLay1'):
    y = residual_block(y, 128, 128, phase_train)
  with tf.variable_scope('residualLay2'):
    y = residual_block(y, 128, 128, phase_train)
  with tf.variable_scope('residualLay3'):
    y = residual_block(y, 128, 128, phase_train)
  with tf.variable_scope('residualLay4'):
    y = residual_block(y, 128, 128, phase_train)
  with tf.variable_scope('residualLay5'):
    y = residual_block(y, 128, 128, phase_train)
  with tf.variable_scope('deconvLay1'):
    y = conv2d_transpose(y, 128, 64, 4, 2)
    y = batch_norm(y, 64, phase_train, scope='bn_init')
    y = tf.nn.relu(y, name='relu_init')
  with tf.variable_scope('deconvLay2'):
    y = conv2d_transpose(y, 64, 32, 4, 2)
    y = batch_norm(y, 32, phase_train, scope='bn_init')
    y = tf.nn.relu(y, name='relu_init')
  with tf.variable_scope('deconvLay3'):
    y = conv2d_transpose(y, 32, 3, 9, 1)
    y = (tf.nn.tanh(y)+1)*127.5
    
  return y
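
For reference, a minimal sketch of wiring up net() on its own (the placeholder shape is illustrative; the training script below feeds batches from reader instead):

import tensorflow as tf
import model

phase_train = tf.placeholder(tf.bool, name='phase_train')
images = tf.placeholder(tf.float32, [None, 256, 256, 3], name='images')
# Inputs are scaled to [0, 1], matching the model.net(images / 255., phase_train)
# call in fast_neural_style.py; the output comes back in [0, 255].
stylized = model.net(images / 255., phase_train)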

 

fast_neural_style.py

from scipy import misc
import os
import time
import tensorflow as tf
import vgg
import model
import reader

tf.app.flags.DEFINE_float("CONTENT_WEIGHT", 5e0, "Weight for content features loss")
tf.app.flags.DEFINE_float("STYLE_WEIGHT", 1e2, "Weight for style features loss")
tf.app.flags.DEFINE_float("TV_WEIGHT", 1e-5, "Weight for total variation loss")
tf.app.flags.DEFINE_string("VGG_PATH", "imagenet-vgg-verydeep-19.mat",
        "Path to vgg model weights")
tf.app.flags.DEFINE_string("MODEL_PATH", "models", "Path to read/write trained models")
tf.app.flags.DEFINE_string("TRAIN_IMAGES_PATH", "train2014", "Path to training images")
tf.app.flags.DEFINE_string("CONTENT_LAYERS", "relu4_2",
        "Which VGG layer to extract content loss from")
tf.app.flags.DEFINE_string("STYLE_LAYERS", "relu1_1,relu2_1,relu3_1,relu4_1,relu5_1",
        "Which layers to extract style from")
tf.app.flags.DEFINE_string("SUMMARY_PATH", "tensorboard", "Path to store Tensorboard summaries")
tf.app.flags.DEFINE_string("STYLE_IMAGES", "style.png", "Styles to train")
tf.app.flags.DEFINE_float("STYLE_SCALE", 1.0, "Scale styles. Higher extracts smaller features")
tf.app.flags.DEFINE_string("CONTENT_IMAGES_PATH", None, "Path to content image(s)")
tf.app.flags.DEFINE_integer("IMAGE_SIZE", 256, "Size of output image")
tf.app.flags.DEFINE_integer("BATCH_SIZE", 1, "Number of concurrent images to train on")

FLAGS = tf.app.flags.FLAGS

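# Total variation loss: mean squared difference between horizontally and vertically
# neighbouring pixels; encourages spatially smooth output images.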
def total_variation_loss(layer):
    shape = tf.shape(layer)
    height = shape[1]
    width = shape[2]
    y = tf.slice(layer, [0,0,0,0], tf.pack([-1,height-1,-1,-1])) - tf.slice(layer, [0,1,0,0], [-1,-1,-1,-1])
    x = tf.slice(layer, [0,0,0,0], tf.pack([-1,-1,width-1,-1])) - tf.slice(layer, [0,0,1,0], [-1,-1,-1,-1])
    return tf.nn.l2_loss(x) / tf.to_float(tf.size(x)) + tf.nn.l2_loss(y) / tf.to_float(tf.size(y))

# TODO: Figure out grams and batch sizes! Doesn't make sense ..
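# Gram matrix per image: reshape the feature maps to (batch, H*W, channels) and take
# channel-by-channel inner products; these correlations are the style representation.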
def gram(layer):
    shape = tf.shape(layer)
    num_images = shape[0]
    num_filters = shape[3]
    size = tf.size(layer)
    filters = tf.reshape(layer, tf.pack([num_images, -1, num_filters]))
    grams = tf.batch_matmul(filters, filters, adj_x=True) / tf.to_float(size / FLAGS.BATCH_SIZE)

    return grams

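# Run the style image(s) through VGG in a separate graph and return the Gram matrices
# of the requested style layers as numpy arrays.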
def get_style_features(style_paths, style_layers):
    with tf.Graph().as_default() as g:
        size = int(round(FLAGS.IMAGE_SIZE * FLAGS.STYLE_SCALE))
        images = tf.pack([reader.get_image(path, size) for path in style_paths])
        net, _ = vgg.net(FLAGS.VGG_PATH, images)
        features = []
        for layer in style_layers:
            features.append(gram(net[layer]))

        with tf.Session() as sess:
            return sess.run(features)

def main(argv=None):
    phase_train = tf.placeholder(tf.bool, name='phase_train')

    if FLAGS.CONTENT_IMAGES_PATH:
        content_images = reader.image(
                FLAGS.BATCH_SIZE,
                FLAGS.IMAGE_SIZE,
                FLAGS.CONTENT_IMAGES_PATH,
                epochs=1,
                shuffle=False,
                crop=False)
        generated_images = model.net(content_images / 255., phase_train)

        output_format = tf.saturate_cast(generated_images + reader.mean_pixel, tf.uint8)
        with tf.Session() as sess:
            file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
            if not file:
                print('Could not find trained model in {}'.format(FLAGS.MODEL_PATH))
                return
            print('Using model from {}'.format(file))
            saver = tf.train.Saver()
            saver.restore(sess, file)
            sess.run(tf.initialize_local_variables())
            coord = tf.train.Coordinator()
            threads = tf.train.start_queue_runners(coord=coord)
            i = 0
            start_time = time.time()
            try:
                while not coord.should_stop():
                    print(i)
                    images_t = sess.run(output_format, {phase_train.name: False})
                    elapsed = time.time() - start_time
                    start_time = time.time()
                    print('Time for one batch: {}'.format(elapsed))

                    for raw_image in images_t:
                        i += 1
                        misc.imsave('out{0:04d}.png'.format(i), raw_image)
            except tf.errors.OutOfRangeError:
                print('Done training -- epoch limit reached')
            finally:
                coord.request_stop()

            coord.join(threads)
        return

    if not os.path.exists(FLAGS.MODEL_PATH):
        os.makedirs(FLAGS.MODEL_PATH)

    style_paths = FLAGS.STYLE_IMAGES.split(',')
    style_layers = FLAGS.STYLE_LAYERS.split(',')
    content_layers = FLAGS.CONTENT_LAYERS.split(',')

    style_features_t = get_style_features(style_paths, style_layers)

    images = reader.image(FLAGS.BATCH_SIZE, FLAGS.IMAGE_SIZE, FLAGS.TRAIN_IMAGES_PATH)
    generated = model.net(images / 255., phase_train)

    net, _ = vgg.net(FLAGS.VGG_PATH, tf.concat(0, [generated, images]))
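
    # `generated` and `images` were stacked into one batch for this VGG pass; the
    # activations are split back apart below when computing the content and style losses.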

    content_loss = 0
    for layer in content_layers:
        generated_images, content_images = tf.split(0, 2, net[layer])
        size = tf.size(generated_images)
        content_loss += tf.nn.l2_loss(generated_images - content_images) / tf.to_float(size)
    content_loss = content_loss / len(content_layers)

    style_loss = 0
    for style_gram, layer in zip(style_features_t, style_layers):
        generated_images, _ = tf.split(0, 2, net[layer])
        size = tf.size(generated_images)
        for style_image in style_gram:
            style_loss += tf.nn.l2_loss(tf.reduce_sum(gram(generated_images) - style_image, 0)) / tf.to_float(size)
    style_loss = style_loss / len(style_layers)

    loss = FLAGS.STYLE_WEIGHT * style_loss + FLAGS.CONTENT_WEIGHT * content_loss + FLAGS.TV_WEIGHT * total_variation_loss(generated)

    global_step = tf.Variable(0, name="global_step", trainable=False)
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss, global_step=global_step)

    output_format = tf.saturate_cast(tf.concat(0, [generated, images]) + reader.mean_pixel, tf.uint8)

    with tf.Session() as sess:
        saver = tf.train.Saver(tf.all_variables())
        file = tf.train.latest_checkpoint(FLAGS.MODEL_PATH)
        if file:
            print('Restoring model from {}'.format(file))
            saver.restore(sess, file)
            sess.run(tf.initialize_local_variables())
        else:
            print('New model initialized')
            sess.run(tf.initialize_all_variables())
            sess.run(tf.initialize_local_variables())

        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(coord=coord)
        start_time = time.time()
        try:
            while not coord.should_stop():
                _, loss_t, step = sess.run([train_op, loss, global_step])
                elapsed_time = time.time() - start_time
                start_time = time.time()
                print(step, loss_t, elapsed_time)
                if step % 100 == 0:
                    print(step, loss_t, elapsed_time)
                    output_t = sess.run(output_format, {phase_train.name: True})
                    for i, raw_image in enumerate(output_t):
                        misc.imsave('out{}.png'.format(i), raw_image)
                        print('Save image.')
                if step % 10000 == 0:
                    saver.save(sess, FLAGS.MODEL_PATH + '/fast-style-model', global_step=step)
                    print ('Save model.')
        except tf.errors.OutOfRangeError:
            print('Done training -- epoch limit reached')
        finally:
            coord.request_stop()

        coord.join(threads)

if __name__ == '__main__':
    tf.app.run()

python3 fast_neural_style.py --TRAIN_IMAGES_PATH coco_img_path --STYLE_IMAGES style.png --BATCH_SIZE 4

python3 fast_neural_style.py --CONTENT_IMAGES_PATH path_to_images_to_transform

Reposted from: https://my.oschina.net/hounLeft/blog/735816
