[AI教程]TensorFlow入门:使用TF-slim的模型做图像分类

1.简介

slim是TensorFlow的一个轻量级库,它基于TensorFlow实现了高层封装,将网络、loss、正则化等概念有调理的组织起来,而不是像原生tensorflow底层接口编程那样,到处充满了超参、网络定义、训练循环等。
例如,定义一个卷积:

with tf.name_scope('conv_a') as scope:
  kernel = tf.Variable(tf.truncated_normal([5, 5, 32, 64], dtype=tf.float32,
                                           stddev=1e-1), name='weights')
  conv = tf.nn.conv2d(input, kernel, [1, 1, 1, 1], padding='SAME')
  biases = tf.Variable(tf.constant(0.0, shape=[64], dtype=tf.float32),
                       trainable=True, name='biases')
  bias = tf.nn.bias_add(conv, biases)
  conv1 = tf.nn.relu(bias, name=scope)

可以看到,其中包括了命名空间,权重变量,偏置值变量,激活函数,网络等等重要概念,如果要修改的话会相当麻烦。而使用slim完成同样的卷积,只需要一行代码:

net = slim.conv2d(input, 64, [5, 5], scope='conv_a')

除了通过消除模板代码,允许用户更紧凑地定义模型之外,slim还封装了计算机视觉的几个常见模型(AlexNet,VGGNet,GoogLeNet,ResNet),对于普通用户可以直接当做黑盒来调用,对于有研究需求的用户也可以基于封装以各种方式进行修改和扩展,省去搭建模型的时间。
如果想要学习slim,这篇博客有详细的说明:TensorFlow - TF-Slim 封装模块

2.文件说明

本篇文章提供的几个文件:
在这里插入图片描述

create_tfrecord.py 定义了操作tfrecords文件的一些接口
train_model.py 训练模型
predict_test.py 测试模型
slim TF-Slim的拷贝
test_image 存放测试图片
dataset 存放数据集,它的结构如下:
在这里插入图片描述
train中是训练集,val是验证集,标签分别保存在相应的txt文件中
文件下载:Slim模型分类

3.训练过程

3.1数据预处理

TensorFlow的训练过程就是数据在网络中流动的过程,官方提供了三种数据读取方式,分别是:

  1. Feeding。通过Python直接读入数据
  2. Reading from files。从文件读取数据
  3. Preloaded data。将数据以constant或者variable的方式直接存储在运算图中

在数据量较大的情况下,官方推荐第二种标准的TensorFlow格式(Standard TensorFlow format)存储数据,文件名后缀为tfrecords。本文提供的create_tfrecord.py中提供了几个重要的函数,对于一般的图像分类问题可以直接使用。
本篇文章以VGG16举例。VGG16模型要求数据大小为224x224,设置create_tfrecord.py的参数运行可以直接得到train224.tfrecords和val224.tfrecords。

if __name__ == '__main__':
    # 参数设置
    resize_height = 224  # 指定存储图片高度
    resize_width = 224  # 指定存储图片宽度
    shuffle=True
    log=5 #打印信息的间隔
    # 产生train.record文件
    image_dir='dataset/train'
    train_labels = 'dataset/train.txt'  # 图片路径
    train_record_output = 'dataset/record/train224.tfrecords'
    create_records(image_dir,train_labels, train_record_output, resize_height, resize_width,shuffle,log)
    train_nums=get_example_nums(train_record_output)
    print("save train example nums={}".format
  • 2
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值