resnet_v2、resnet_v1、inception等网络简单实现及部署

版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/luoyexuge/article/details/82845201

resnet_v2、resnet_v1、inception 这些网络在 TensorFlow 中封装得比较死,全部封装在 slim 模块下;当然,一些更高级的网络暂时没看到封装在里面,比如胶囊网络以及 inception_v4。对应的 fine-tune 预训练模型下载地址如下:https://github.com/tensorflow/models/tree/master/research/slim 。下面的内容基本上是搬运,没什么技术含量,看看就可以。这里以 resnet_v1_101 为例(下面的代码实际使用的是 resnet_v1,而不是 resnet_v2),读取 tfrecords 的部分就不再赘述:

 

训练以及保存为checkpoint格式:

import tensorflow as tf
import time
import os
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1

# Command-line configuration. Flags are registered from (name, default, help)
# tables in the same order as before: integers, then floats, then the string.
FLAGS = tf.app.flags.FLAGS

for _name, _default, _help in (
        ('train_iters', 200000, ''),
        ('display_step', 100, ''),
        ('batch_size', 8, ''),
        ('num_threads', 2, ''),
        ('image_classes', 20, ''),
        ('image_crop_height', 224, ''),
        ('image_crop_width', 224, ''),
        ('image_channels', 3, ''),
        ('image_height', 256, ''),
        ('image_width', 256, ''),
        ('image_mean', 128, '')):
    tf.app.flags.DEFINE_integer(_name, _default, _help)

# NOTE(review): dropout_rate is not read anywhere in the visible code.
for _name, _default, _help in (
        ('learning_rate', 0.0001, ''),
        ('accuracy_limit', 0.95, ''),
        ('dropout_rate', 0.5, '')):
    tf.app.flags.DEFINE_float(_name, _default, _help)

tf.app.flags.DEFINE_string("checkpointDir", "model/", "oss info")

# Keep the loop temporaries out of the module namespace.
del _name, _default, _help


def read_and_decode(filename_queue, num_epochs=10):
    """Build a one-example TFRecord decoding pipeline.

    Despite its name, ``filename_queue`` is a filename glob pattern. It is
    expanded with ``tf.train.match_filenames_once`` and fed to a
    non-shuffling ``string_input_producer``.

    Each record must hold a ``label`` feature (the class index serialized as
    a string) and a ``data`` feature (JPEG bytes). The image is decoded to
    3 channels and resized to ``FLAGS.image_height`` x ``FLAGS.image_width``.
    It is then shifted by ``-FLAGS.image_mean`` and randomly cropped to
    ``FLAGS.image_crop_height`` x ``FLAGS.image_crop_width``.

    Both ``match_filenames_once`` and an epoch-limited producer create local
    variables, so callers must run ``tf.local_variables_initializer()``.

    Args:
        filename_queue: glob pattern matching the TFRecord files.
        num_epochs: number of passes over the files before the producer
            raises ``tf.errors.OutOfRangeError``; ``None`` cycles forever.

    Returns:
        ``(cropped_image, label)``: a float32 image tensor and an int32 label
        for a single example.
    """
    all_paths = tf.train.match_filenames_once(filename_queue)
    input_path_queue = tf.train.string_input_producer(
        all_paths, shuffle=False, num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, example = reader.read(input_path_queue)
    features = tf.parse_single_example(
        example,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'data': tf.FixedLenFeature([], tf.string)})
    # Labels are stored as strings: parse to a number, then cast to int32.
    label = tf.cast(tf.string_to_number(features['label']), tf.int32)

    # method=0 is bilinear resizing.
    image = tf.image.resize_images(
        tf.image.decode_jpeg(features['data'], 3),
        [FLAGS.image_height, FLAGS.image_width], 0)
    image = tf.subtract(tf.to_float(image), FLAGS.image_mean)

    cropped_image = tf.random_crop(
        image,
        [FLAGS.image_crop_height, FLAGS.image_crop_width,
         FLAGS.image_channels])

    return cropped_image, label


def inputs(file, batch_size, num_epochs, num_threads=10, capacity=10000,
           min_after_dequeue=9999):
    """Return shuffled batches of images and one-hot labels.

    Args:
        file: glob pattern matching the TFRecord files.
        batch_size: number of examples per batch.
        num_epochs: passes over the data before the pipeline raises
            ``tf.errors.OutOfRangeError``. A falsy value (0/None) means
            "repeat forever".
        num_threads: enqueue threads used by ``tf.train.shuffle_batch``.
        capacity: maximum number of examples buffered by the shuffle queue.
            At 224x224x3 float32 each example is ~0.6 MB, so the default of
            10000 buffers roughly 6 GB per pipeline; lower it (together with
            ``min_after_dequeue``) on memory-constrained machines.
        min_after_dequeue: minimum examples kept in the queue after a
            dequeue, which controls shuffle quality.

    Returns:
        ``(images, labels)``: images of shape
        [batch, crop_height, crop_width, channels] and float one-hot labels
        of shape [batch, FLAGS.image_classes].
    """
    if not num_epochs:
        num_epochs = None
    feature, label = read_and_decode(file, num_epochs=num_epochs)
    image_batch, label_batch = tf.train.shuffle_batch(
        [feature, label],
        batch_size=batch_size,
        num_threads=num_threads,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue)

    reshaped_image_batch = tf.reshape(
        image_batch,
        [-1, FLAGS.image_crop_height, FLAGS.image_crop_width,
         FLAGS.image_channels])
    one_hot_label_batch = tf.to_float(
        tf.one_hot(label_batch, FLAGS.image_classes, 1, 0))
    return reshaped_image_batch, one_hot_label_batch



# loss function
def loss(logits, labels):
    """Return the batch-mean softmax cross-entropy of ``logits`` vs ``labels``.

    ``labels`` are expected to be one-hot (or probability) vectors with the
    same shape as ``logits``.
    """
    per_example = tf.nn.softmax_cross_entropy_with_logits(labels=labels,
                                                          logits=logits)
    mean_loss = tf.reduce_mean(per_example)
    return mean_loss


# validate function
def accuracy(logits, labels):
    """Return the fraction of examples whose top-scoring class matches the label.

    Args:
        logits: unnormalised class scores; presumably [batch, num_classes]
            after the caller squeezes the network output — TODO confirm.
        labels: one-hot labels with the same layout as ``logits``.

    Returns:
        Scalar float32 tensor in [0, 1].
    """
    # softmax is monotonic, so argmax over raw logits picks the same class
    # without the extra exp/normalisation.
    predicted = tf.argmax(logits, 1)
    expected = tf.argmax(labels, 1)
    correct = tf.to_float(tf.equal(predicted, expected))
    return tf.reduce_mean(correct)


# train task
def _restore_pretrained(sess, checkpoint_path):
    """Restore all model variables except the classifier head from a checkpoint.

    Variables whose op name contains "logits"/"Logits" are skipped, because
    the pretrained checkpoint's head was presumably trained for a different
    class count. Those variables keep their initializer values.
    """
    includes = []
    for var in slim.get_model_variables():
        if "logits" in var.op.name or "Logits" in var.op.name:
            print(var.op.name)
        else:
            includes.append(var.op.name)
    print(len(includes))
    variables_to_restore = slim.get_variables_to_restore(include=includes)
    tf.train.Saver(variables_to_restore).restore(sess, checkpoint_path)


def train_task(train_path, val_path):
    """Fine-tune ResNet-v1-101 on TFRecords and checkpoint it periodically.

    The backbone is initialised from a pretrained ``resnet_v1_101.ckpt``; the
    logits layer starts from scratch. Every ``FLAGS.display_step`` steps the
    function:

    * evaluates accuracy on one validation batch,
    * saves a full checkpoint (including the logits layer) under
      ``FLAGS.checkpointDir``,
    * stops early once that accuracy exceeds ``FLAGS.accuracy_limit``.

    Training also stops when either input pipeline exhausts its epochs.

    Args:
        train_path: glob pattern of the training TFRecord files.
        val_path: glob pattern of the validation TFRecord files.
    """
    train_images, train_labels = inputs(train_path, FLAGS.batch_size, 2)
    val_images, val_labels = inputs(val_path, FLAGS.batch_size, 2)

    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        predictions, _ = resnet_v1.resnet_v1_101(
            train_images, num_classes=FLAGS.image_classes, is_training=True)
        # Same weights (reuse=True), but batch norm uses its moving averages.
        val_predict, _ = resnet_v1.resnet_v1_101(
            val_images, num_classes=FLAGS.image_classes, is_training=False,
            reuse=True)

    # Drop the size-1 spatial dimensions of the logits.
    predictions = tf.squeeze(predictions)
    val_predict = tf.squeeze(val_predict)

    print(predictions.shape)
    print(train_labels.shape)
    train_loss = loss(predictions, train_labels)
    val_accuracy = accuracy(val_predict, val_labels)

    opt = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    # Run the batch-norm moving-average updates with every training step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = opt.minimize(train_loss)

    checkpoint_path = os.path.join("/Users/zhoumeixu/Downloads",
                                   'resnet_v1_101.ckpt')
    ckp_path = os.path.join(FLAGS.checkpointDir, "model_test.ckpt")
    # Saver.save() will not create the target directory itself.
    tf.gfile.MakeDirs(FLAGS.checkpointDir)
    # Save *all* variables. The restore saver deliberately excludes the
    # logits layer, so reusing it here would drop the trained head.
    saver = tf.train.Saver()

    with tf.Session() as sess:
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())

        print('started training (validate every %d runs)' % FLAGS.display_step)
        _restore_pretrained(sess, checkpoint_path)

        coord = tf.train.Coordinator()
        runners = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for step in range(FLAGS.train_iters):
                if coord.should_stop():
                    break
                t1 = time.time()
                local_loss, _ = sess.run([train_loss, train_op])
                duration = time.time() - t1
                print('step=%d, loss=%f, duration=%fs.' %
                      (step, local_loss, duration))
                if step % FLAGS.display_step == 0 and step > 0:
                    local_accuracy = sess.run(val_accuracy)
                    print('step=%d, accuracy=%f.' % (step, local_accuracy))
                    save_path = saver.save(sess, ckp_path, global_step=step)
                    print("Model saved in file: %s" % save_path)
                    if local_accuracy > FLAGS.accuracy_limit:
                        print('accuracy was higher than supposed %f' %
                              FLAGS.accuracy_limit)
                        break
            print('all steps was done.')
        except tf.errors.OutOfRangeError:
            # The epoch-limited input producers close their queues once the
            # epoch budget is spent; treat that as the end of training.
            print('input exhausted, stopping training.')
        finally:
            coord.request_stop()
            coord.join(runners)


def main(_):
    """Entry point for ``tf.app.run``: locate the TFRecord shards and train.

    The training and validation directories are hard-coded local paths.
    """
    train_dir = "/Users/zhoumeixu/Desktop/train/"
    val_dir = "/Users/zhoumeixu/Desktop/valid/"

    train_path = os.path.join(train_dir, "tr_recor*")
    print(train_path)

    # NOTE(review): validation uses "tr_reco*" while training uses
    # "tr_recor*" — confirm this difference is intended.
    val_path = os.path.join(val_dir, "tr_reco*")
    print(val_path)

    train_task(train_path, val_path)


if __name__ == '__main__':
    # tf.app.run parses the command-line flags, then calls main(argv).
    tf.app.run(main=main)

 

导出为 SavedModel 格式(saved_model.pb 加 variables 目录,供 Java 端通过 SavedModelBundle 加载):

import tensorflow as tf
import tensorflow.contrib.slim  as slim
from  tensorflow.contrib.slim.python.slim.nets import  resnet_v1
import  os

def save_model():
    """Export the newest fine-tuned checkpoint as a SavedModel under ``model/2``.

    The exported graph classifies a single image:

    * ``input``: float32 placeholder of shape [height, width, 3]. A batch
      dimension of 1 is added inside the graph.
    * ``output``: softmax over the 20 classes, squeezed to shape [20].

    Both tensors are also published under the ``'prediction'`` signature
    with the SERVING tag.

    NOTE(review): the graph applies no preprocessing. The training pipeline
    resized images to 256x256, subtracted ``image_mean`` (128) and cropped
    to 224x224, so clients must prepare pixels the same way.

    Raises:
        ValueError: if ``modelpath`` contains no checkpoint.
        tf.errors.NotFoundError: if the checkpoint lacks a variable of the
            inference graph, e.g. a checkpoint saved without the logits layer.

    SavedModelBuilder also refuses to overwrite an existing ``model/2``
    directory.
    """
    graph = tf.get_default_graph()
    x = tf.placeholder(tf.float32, [None, None, 3], name="input")
    # The network consumes batches; serve one image at a time.
    image_batch = tf.expand_dims(x, 0)
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        val_predict, _ = resnet_v1.resnet_v1_101(
            image_batch, num_classes=20, is_training=False,
            reuse=tf.AUTO_REUSE)

    # Drop every size-1 dimension (batch and spatial) -> [num_classes].
    val_predict = tf.squeeze(val_predict)
    pred = tf.nn.softmax(val_predict, name="output")

    # Restore *every* variable, including the logits layer. The classifier
    # head must come from the fine-tuned checkpoint. Skipping it (and
    # running the global initializer) would export a randomly initialised
    # head; if the checkpoint lacks it, restore fails loudly instead.
    restore_saver = tf.train.Saver()
    modelpath = "/Users/zhoumeixu/Documents/python/credit-transform/model/"
    with tf.Session(graph=graph) as sess:
        latest_ckpt = tf.train.latest_checkpoint(modelpath)
        print(latest_ckpt)
        if latest_ckpt is None:
            raise ValueError("no checkpoint found in %s" % modelpath)
        restore_saver.restore(sess, latest_ckpt)

        export_path = os.path.join(
            tf.compat.as_bytes("model"),
            tf.compat.as_bytes(str(2)))
        print('Exporting trained model to', export_path)

        builder = tf.saved_model.builder.SavedModelBuilder(export_path)

        tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
        tensor_info_y = tf.saved_model.utils.build_tensor_info(pred)

        prediction_signature = (
            tf.saved_model.signature_def_utils.build_signature_def(
                inputs={'input': tensor_info_x},
                outputs={'output': tensor_info_y},
                method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={
                'prediction':
                    prediction_signature,
            }
        )

        builder.save()


if __name__ == "__main__":
    save_model()

 

在 Java 中调用:

package com.alibaba.tensorflow;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;

/**
 * Demo client that loads the exported SavedModel and runs one prediction.
 *
 * <p>The graph is fed through the op named {@code input} (a float image of
 * shape [height, width, 3]) and read from the op named {@code output}
 * (softmax probabilities of shape [num_classes]).
 *
 * <p>NOTE(review): the exported graph does no preprocessing. Training fed
 * raw 0-255 pixels minus 128, so the random [0, 1) values produced by
 * {@link #inputs()} only exercise the plumbing and do not yield meaningful
 * predictions.
 */
public class ResNetV2 {
	private static SavedModelBundle bundle = null;
	static {
		String classpath = "/Users/zhoumeixu/Documents/python/credit-transform/model/2";
		bundle = TensorflowUtils.loadmodel(classpath);

	}

	/** Builds a 224x224x3 image filled with uniform random values in [0, 1). */
	public static float[][][] inputs() {
		float[][][] result = new float[224][224][3];
		for (int i = 0; i < 224; i++) {
			for (int j = 0; j < 224; j++) {
				for (int num = 0; num < 3; num++) {
					result[i][j][num] = (float) Math.random();
				}
			}
		}
		return result;
	}

	/**
	 * Runs the model on a single image and returns the class probabilities.
	 *
	 * <p>Tensors wrap native memory that must be released with close(), so
	 * both the input and the fetched output are closed by try-with-resources.
	 *
	 * @param arr image of shape [height, width, 3]
	 * @return softmax probabilities, one per class
	 */
	public static float[] getResult(float[][][] arr) {
		try (Tensor tensor = Tensor.create(arr);
				Tensor result = bundle.session().runner().feed("input", tensor).fetch("output").run().get(0)) {
			long[] rshape = result.shape();
			// The exported output is squeezed to [num_classes], so the first
			// dimension is the class count, not a batch size.
			int numClasses = (int) rshape[0];
			return (float[]) result.copyTo(new float[numClasses]);
		}
	}

	public static void main(String[] args) {

		float[][][] arr = inputs();
		float[] result = getResult(arr);

		System.out.println(result.length);
		System.out.println(Arrays.toString(result));
		
		
	}

}

没什么技术含量,注意一些细节就可以。实际上训练网络还有更简单的方式:直接用命令行,连代码都不需要写,只需要把数据准备好。

阅读更多
换一批

没有更多推荐了,返回首页