1.简介
slim是TensorFlow的一个轻量级库,它基于TensorFlow实现了高层封装,将网络、loss、正则化等概念有调理的组织起来,而不是像原生tensorflow底层接口编程那样,到处充满了超参、网络定义、训练循环等。
例如,定义一个卷积:
with tf.name_scope('conv_a') as scope:
kernel = tf.Variable(tf.truncated_normal([5, 5, 32, 64], dtype=tf.float32,
stddev=1e-1), name='weights')
conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
trainable=True, name='biases')
bias = tf.nn.bias_add(conv, biases)
conv1 = tf.nn.relu(bias, name=scope)
可以看到,其中包括了命名空间,权重变量,偏置值变量,激活函数,网络等等重要概念,如果要修改的话会相当麻烦。而使用slim完成同样的卷积,只需要一行代码:
net = slim.conv2d(input, 64, [5, 5], scope='conv_a')
除了通过消除模板代码,允许用户更紧凑地定义模型之外,slim还封装了计算机视觉的几个常见模型(AlexNet,VGGNet,GoogLeNet,ResNet),对于普通用户可以直接当做黑盒来调用,对于有研究需求的用户也可以基于封装以各种方式进行修改和扩展,省去搭建模型的时间。
如果想要学习slim,这篇博客有详细的说明:TensorFlow - TF-Slim 封装模块
2.文件说明
本篇文章提供的几个文件:
create_tfrecord.py 定义了操作tfrecords文件的一些接口
train_model.py 训练模型
predict_test.py 测试模型
slim TF-Slim的拷贝
test_image 存放测试图片
dataset 存放数据集,它的结构如下:
train中是训练集,val是验证集,标签分别保存在相应的txt文件中
文件下载:Slim模型分类
3.训练过程
3.1数据预处理
TensorFlow的训练过程就是数据在网络中流动的过程,官方提供了三种数据读取方式,分别是:
- Feeding。通过Python直接读入数据
- Reading from files。从文件读取数据
- Preloaded data。将数据以constant或者variable的方式直接存储在运算图中
在数据量较大的情况下,官方推荐第二种标准的TensorFlow格式(Standard TensorFlow format)存储数据,文件名后缀为tfrecords。本文提供的create_tfrecord.py中提供了几个重要的函数,对于一般的图像分类问题可以直接使用。
本篇文章以VGG16举例。VGG16模型要求数据大小为224x224,设置create_tfrecord.py的参数运行可以直接得到train224.tfrecords和val224.tfrecords。
if __name__ == '__main__':
# 参数设置
resize_height = 224 # 指定存储图片高度
resize_width = 224 # 指定存储图片宽度
shuffle=True
log=5 #打印信息的间隔
# 产生train.record文件
image_dir='dataset/train'
train_labels = 'dataset/train.txt' # 图片路径
train_record_output = 'dataset/record/train224.tfrecords'
create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
train_nums=get_example_nums(train_record_output)
print("save train example nums={}".format