resnet_v2、resnet_v1、inception这些网络在tensorflow中封装得比较死,全部封装在slim模块下;当然一些更新的网络暂时还没有封装进去,比如胶囊网络以及inception v4。对应的finetune模型下载地址如下:https://github.com/tensorflow/models/tree/master/research/slim。下面写的博客基本上是搬运工性质的,没什么技术含量,看看就可以。以resnet_v1为例,读取tfrecords的部分就不再赘述:
训练以及保存为checkpoint格式:
import tensorflow as tf
import time
import os
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.nets import resnet_v1
# Command-line flags (tf.app.flags); the author left the help strings empty,
# so each flag's meaning is documented in a trailing comment instead.
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer('train_iters', 200000, '')  # max number of training steps
tf.app.flags.DEFINE_integer('display_step', 100, '')  # validation interval (steps); see train_task
tf.app.flags.DEFINE_integer('batch_size', 8, '')  # examples per training batch
tf.app.flags.DEFINE_integer('num_threads', 2, '')  # NOTE(review): defined but apparently unused below
tf.app.flags.DEFINE_integer('image_classes', 20, '')  # number of output classes
tf.app.flags.DEFINE_integer('image_crop_height', 224, '')  # random-crop height fed to the network
tf.app.flags.DEFINE_integer('image_crop_width', 224, '')  # random-crop width fed to the network
tf.app.flags.DEFINE_integer('image_channels', 3, '')  # RGB
tf.app.flags.DEFINE_integer('image_height', 256, '')  # resize height before cropping
tf.app.flags.DEFINE_integer('image_width', 256, '')  # resize width before cropping
tf.app.flags.DEFINE_integer('image_mean', 128, '')  # scalar subtracted from every pixel
tf.app.flags.DEFINE_float('learning_rate', 0.0001, '')  # SGD learning rate
tf.app.flags.DEFINE_float('accuracy_limit', 0.95, '')  # stop early once validation accuracy exceeds this
tf.app.flags.DEFINE_float('dropout_rate', 0.5, '')  # NOTE(review): defined but apparently unused below
tf.app.flags.DEFINE_string("checkpointDir", "model/", "oss info")  # where training checkpoints are written
def read_and_decode(filename_queue, num_epochs=10):
    """Build a queue-based input pipeline yielding one (image, label) pair.

    Args:
        filename_queue: glob pattern matching the TFRecord files. Despite the
            name, this is a pattern string, not a queue — it is expanded via
            tf.train.match_filenames_once.
        num_epochs: number of passes over the matched files. Previously
            hard-coded to 10; exposed as a backward-compatible parameter.

    Returns:
        (cropped_image, label): a random crop of shape
        (image_crop_height, image_crop_width, image_channels) of the decoded,
        mean-subtracted JPEG, and its int32 class label.
    """
    all_paths = tf.train.match_filenames_once(filename_queue)
    input_path_queue = tf.train.string_input_producer(
        all_paths, shuffle=False, num_epochs=num_epochs)
    reader = tf.TFRecordReader()
    _, example = reader.read(input_path_queue)
    features = tf.parse_single_example(
        example,
        features={
            'label': tf.FixedLenFeature([], tf.string),
            'data': tf.FixedLenFeature([], tf.string)})
    # Labels are stored as decimal strings, e.g. "7" -> 7.
    label_queue = tf.cast(tf.string_to_number(features['label']), tf.int32)
    # Decode JPEG to 3 channels, then resize to the pre-crop size
    # (method 0 = bilinear).
    image_queue = tf.image.resize_images(
        tf.image.decode_jpeg(features['data'], 3),
        [FLAGS.image_height, FLAGS.image_width], 0)
    # Roughly center pixel values around zero by subtracting the scalar mean.
    image_queue = tf.subtract(tf.to_float(image_queue), FLAGS.image_mean)
    # Random crop provides cheap data augmentation.
    cropped_image_queue = tf.random_crop(
        image_queue,
        [FLAGS.image_crop_height, FLAGS.image_crop_width,
         FLAGS.image_channels])
    return cropped_image_queue, label_queue
def inputs(file, batch_size, num_epochs):
    """Return shuffled (images, one_hot_labels) batch tensors for `file`.

    Args:
        file: glob pattern of TFRecord files, forwarded to read_and_decode.
        batch_size: examples per batch. Bug fix: the original ignored this
            parameter and always used FLAGS.batch_size; the argument is now
            honoured (callers pass FLAGS.batch_size, so behavior is unchanged
            for existing call sites).
        num_epochs: accepted for API compatibility but currently NOT
            propagated — the epoch count is fixed inside read_and_decode.
            TODO(review): thread this through to read_and_decode.

    Returns:
        (images, labels): images of shape
        (batch, crop_h, crop_w, channels) and float one-hot labels of shape
        (batch, image_classes).
    """
    if not num_epochs:
        num_epochs = None
    feature, label = read_and_decode(file)
    image_batch_queue, label_batch_queue = tf.train.shuffle_batch(
        [feature, label],
        batch_size=batch_size,
        num_threads=10,
        capacity=10000,
        min_after_dequeue=9999)
    reshaped_image_batch_queue = tf.reshape(
        image_batch_queue,
        [-1, FLAGS.image_crop_height, FLAGS.image_crop_width,
         FLAGS.image_channels])
    one_hot_label_batch_queue = tf.to_float(
        tf.one_hot(label_batch_queue, FLAGS.image_classes, 1, 0))
    return reshaped_image_batch_queue, one_hot_label_batch_queue
def loss(logits, labels):
    """Mean softmax cross-entropy between `logits` and one-hot `labels`."""
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        logits=logits, labels=labels)
    return tf.reduce_mean(per_example_loss)
def accuracy(logits, labels):
    """Fraction of examples whose argmax prediction matches the label."""
    predicted = tf.argmax(tf.nn.softmax(logits), 1)
    expected = tf.argmax(labels, 1)
    correct = tf.to_float(tf.equal(predicted, expected))
    return tf.reduce_mean(correct)
# train task
def train_task(train_path, val_path):
    """Fine-tune ResNet-v1-101 on the TFRecords matched by `train_path`,
    validating on `val_path`, and periodically checkpoint to
    FLAGS.checkpointDir.

    The pre-trained checkpoint is restored for every variable except the
    logits head (its shape is 20 classes here vs. 1000 in the released
    model), which trains from scratch.

    Bug fix vs. the original: checkpoints are now written with a FULL saver.
    The original reused the restore saver (built without the logits
    variables), so the fine-tuned classification head was never saved.
    """
    train_images, train_labels = inputs(train_path, FLAGS.batch_size, 2)
    val_images, val_labels = inputs(val_path, FLAGS.batch_size, 2)
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        predictions, end_points = resnet_v1.resnet_v1_101(
            train_images, num_classes=20, is_training=True)
        # Second tower shares weights (reuse=True) and disables train-time
        # behaviors such as batch-norm updates.
        val_predict, _ = resnet_v1.resnet_v1_101(
            val_images, num_classes=20, is_training=False, reuse=True)
    # resnet_v1 outputs (batch, 1, 1, classes); squeeze to (batch, classes).
    predictions = tf.squeeze(predictions)
    val_predict = tf.squeeze(val_predict)
    print(predictions.shape)
    print(train_labels.shape)
    train_loss = loss(predictions, train_labels)
    train_accuracy = accuracy(val_predict, val_labels)
    opt = tf.train.GradientDescentOptimizer(FLAGS.learning_rate)
    # Batch-norm moving averages live in UPDATE_OPS; force them to run
    # before each gradient step.
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies([tf.group(*update_ops)]):
        grads = opt.compute_gradients(train_loss)
        train_op = opt.apply_gradients(grads)
    with tf.Session() as sess:
        # local_variables_initializer is required by match_filenames_once /
        # num_epochs counters.
        sess.run(tf.local_variables_initializer())
        sess.run(tf.global_variables_initializer())
        print('started training (validate every %d runs)' % (FLAGS.display_step))
        checkpoint_path = os.path.join("/Users/zhoumeixu/Downloads", 'resnet_v1_101.ckpt')
        exclusions = []
        includes = []
        for var in slim.get_model_variables():
            # The classification head ("logits"/"Logits") cannot be restored:
            # its shape differs from the 1000-class pre-trained model.
            if "logits" in var.op.name or "Logits" in var.op.name:
                print(var.op.name)
                exclusions.append(var.op.name)
            else:
                includes.append(var.op.name)
        print(len(includes))
        variables_to_restore = slim.get_variables_to_restore(include=includes)
        restore_saver = tf.train.Saver(variables_to_restore)
        restore_saver.restore(sess, checkpoint_path)
        # Full saver so checkpoints include the (trainable) logits head.
        train_saver = tf.train.Saver()
        coord = tf.train.Coordinator()
        runners = tf.train.start_queue_runners(sess=sess, coord=coord)
        for step in range(FLAGS.train_iters):
            if coord.should_stop():
                break
            t1 = time.time()
            local_loss, _ = sess.run([train_loss, train_op])
            duration = time.time() - t1
            print('step=%d, loss=%f, duration=%fs.' % \
                  (step, local_loss, duration))
            if step % 10 == 0 and step > 0:
                local_accuracy = sess.run(train_accuracy)
                print('step=%d, accuracy=%f.' % (step, local_accuracy))
                ckp_path = os.path.join(FLAGS.checkpointDir, "model_test.ckpt")
                save_path = train_saver.save(sess, ckp_path, global_step=step)
                print("Model saved in file: %s" % save_path)
                # Early stop once validation accuracy is good enough.
                if local_accuracy > FLAGS.accuracy_limit:
                    print('accuracy was higher than supposed %f' % FLAGS.accuracy_limit)
                    break
        print('all steps was done.')
        coord.request_stop()
        coord.join(runners)
        # Note: no explicit sess.close(); the `with` block handles it.
def main(_):
    """Entry point for tf.app.run: build the TFRecord glob patterns for the
    train/validation sets and start the training task."""
    train_path = os.path.join("/Users/zhoumeixu/Desktop/train/", "tr_recor*")
    print(train_path)
    val_path = os.path.join("/Users/zhoumeixu/Desktop/valid/", "tr_reco*")
    print(val_path)
    train_task(train_path, val_path)


if __name__ == '__main__':
    tf.app.run()
转化为SavedModel(pb)格式:
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets import resnet_v1
import os
def save_model():
    """Rebuild the ResNet-v1-101 inference graph, restore the latest
    training checkpoint (all variables except the logits head), and export
    the result as a SavedModel under model/2 with a single 'prediction'
    signature (input -> softmax output)."""
    graph = tf.get_default_graph()
    # Serving input is a single H x W x 3 image; a batch axis is prepended
    # so the network sees shape (1, H, W, 3).
    x = tf.placeholder(tf.float32, [None, None, 3], name="input")
    batched = tf.expand_dims(x, 0)
    with slim.arg_scope(resnet_v1.resnet_arg_scope()):
        val_predict, _ = resnet_v1.resnet_v1_101(
            batched, num_classes=20, is_training=False, reuse=tf.AUTO_REUSE)
    val_predict = tf.squeeze(val_predict)
    pred = tf.nn.softmax(val_predict, name="output")

    # Restore everything except the classification head, mirroring the
    # variable split used during training.
    excluded, included = [], []
    for var in slim.get_model_variables():
        name = var.op.name
        if "logits" in name or "Logits" in name:
            print(name)
            excluded.append(name)
        else:
            included.append(name)
    print(len(included))
    variables_to_restore = slim.get_variables_to_restore(include=included)
    restore_saver = tf.train.Saver(variables_to_restore)

    modelpath = "/Users/zhoumeixu/Documents/python/credit-transform/model/"
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        latest_ckpt = tf.train.latest_checkpoint(modelpath)
        print(latest_ckpt)
        restore_saver.restore(sess, latest_ckpt)

        export_path = os.path.join(
            tf.compat.as_bytes("model"),
            tf.compat.as_bytes(str(2)))
        print('Exporting trained model to', export_path)
        builder = tf.saved_model.builder.SavedModelBuilder(export_path)
        input_info = tf.saved_model.utils.build_tensor_info(x)
        output_info = tf.saved_model.utils.build_tensor_info(pred)
        signature = tf.saved_model.signature_def_utils.build_signature_def(
            inputs={'input': input_info},
            outputs={'output': output_info},
            method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
        builder.add_meta_graph_and_variables(
            sess, [tf.saved_model.tag_constants.SERVING],
            signature_def_map={'prediction': signature})
        builder.save()


if __name__ == "__main__":
    save_model()
Java中调用:
package com.alibaba.tensorflow;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.regex.Pattern;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
public class ResNetV2 {

    /** Serving bundle loaded once per JVM from the exported SavedModel. */
    private static SavedModelBundle bundle = null;

    static {
        String classpath = "/Users/zhoumeixu/Documents/python/credit-transform/model/2";
        bundle = TensorflowUtils.loadmodel(classpath);
    }

    /**
     * Builds a random 224x224x3 float image used as dummy model input.
     *
     * @return a freshly allocated array filled with values in [0, 1)
     */
    public static float[][][] inputs() {
        float[][][] result = new float[224][224][3];
        for (int i = 0; i < 224; i++) {
            for (int j = 0; j < 224; j++) {
                for (int num = 0; num < 3; num++) {
                    result[i][j][num] = (float) Math.random();
                }
            }
        }
        return result;
    }

    /**
     * Runs the model on one image and returns the softmax scores.
     *
     * Bug fix: {@code Tensor} wraps native (off-heap) memory and must be
     * closed; the original leaked both the input and the output tensor on
     * every call. Both are now released via try-with-resources.
     *
     * @param arr a single HxWx3 image (no batch axis; the graph adds it)
     * @return the flattened "output" tensor (one score per class)
     */
    public static float[] getResult(float[][][] arr) {
        try (Tensor tensor = Tensor.create(arr);
             Tensor result = bundle.session().runner()
                     .feed("input", tensor).fetch("output").run().get(0)) {
            long[] rshape = result.shape();
            // After tf.squeeze the output is 1-D; rshape[0] is its length.
            int batchSize = (int) rshape[0];
            return (float[]) result.copyTo(new float[batchSize]);
        }
    }

    public static void main(String[] args) {
        float[][][] arr = inputs();
        float[] result = getResult(arr);
        System.out.println(result.length);
        System.out.println(Arrays.toString(result));
    }
}
没什么技术含量,注意一些细节就可以。实际上训练网络还有更简单的方式:直接用slim提供的命令行脚本,代码都不需要写,只需要把数据准备好。