Tensorflow高层封装简介

目前比较主流的 TensorFlow 高层封装主要有 TensorFlow-Slim 、 TFLearn 、Keras 和 Estimator 。
Tensor Flow-Slim 是 Google 官方给出的相对较早的 TensorFlow 高层封装, Goog le 通过TensorFlow-Slim 开源了 一些己经训练好的图像分析的模型,所以目前在图像识别问题中TensorFlow-Slim 仍被较多地使用。

用TensorFlow-Slim 在 MNIST 数据集上实现 LeNet-5 模型。

import tensorflow as tf
import tensorflow.contrib.slim as slim
import numpy as np

from tensorflow.examples.tutorials.mnist import input_data

# 通过TensorFlow-Slim来定义LeNet-5的网络结构。
def lenet5(inputs):
    inputs = tf.reshape(inputs, [-1, 28, 28, 1])
    net = slim.conv2d(inputs, 32, [5, 5], padding='SAME', scope='layer1-conv')#卷积
    net = slim.max_pool2d(net, 2, stride=2, scope='layer2-max-pool')#池化
    net = slim.conv2d(net, 64, [5, 5], padding='SAME', scope='layer3-conv')
    net = slim.max_pool2d(net, 2, stride=2, scope='layer4-max-pool')
    net = slim.flatten(net, scope='flatten')#拉成一维
    net = slim.fully_connected(net, 500, scope='layer5')#全连接
    net = slim.fully_connected(net, 10, scope='output')
    return net
def train(mnist):
    x = tf.placeholder(tf.float32, [None, 784], name='x-input')#输入
    y_ = tf.placeholder(tf.float32, [None, 10], name='y-input')#正确标签
    y = lenet5(x)#前向传播

    cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))#交叉熵损失函数
    loss = tf.reduce_mean(cross_entropy)

    train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss)#训练方法
    with tf.Session() as sess:
        tf.global_variables_initializer().run()
        for i in range(3000):
            xs, ys = mnist.train.next_batch(100)
            _, loss_value = sess.run([train_op, loss], feed_dict={x: xs, y_: ys})

            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (i, loss_value))
def main(argv=None):
    mnist = input_data.read_data_sets("../../datasets/MNIST_data", one_hot=True)
    train(mnist)

if __name__ == '__main__':
    main()

TensorFlow-Slim 主要的作用是使模型定义更加简洁, 基本上每层网络可以通过一句话来实现。除了对单层网络结构, TensorFlow-Slim 还对数据预处理、损失函数、学习过程、测试过程等都提供了高层封装。 TensorFlow-Slim 最特别的一个地方是它对一些标准的神经网络模型进行了封装,比如 VGG 、 Inception 以及 ResNet ,而 且 Google 开源的训练好的图像分类模型基本都是通过TensorFlow-Slim 实现的。

与 TensorFlow-Slim 相比, TFLeam 是一个更加简洁的 TensorFlow 高层封装。通过TFLeam 可以更加容易地完成模型定义、模型训练以及模型评测的全过程。 TFLeam 没有集成在 TensorFlow 的安装包中,故需要单独安装 。

通过以下命令就可以安装 TFLeam :

pip install tflearn
#pip3 install tflearn

用 TFLeam 在 MNIST 数据集上实现 LeNet-5 模型。

import tflearn
from tflearn.layers.core import input_data, dropout, fully_connected
from tflearn.layers.conv import conv_2d, max_pool_2d
from tflearn.layers.estimator import regression
 
import tflearn.datasets.mnist as mnist
#读取数据
trainX, trainY, testX, testY = mnist.load_data(
    data_dir="../../datasets/MNIST_data", one_hot=True)
# 将图像数据resize成卷积卷积神经网络输入的格式。
trainX = trainX.reshape([-1, 28, 28, 1])
testX = testX.reshape([-1, 28, 28, 1])
 
# 构建神经网络。
net = input_data(shape=[None, 28, 28, 1], name='input')
net = conv_2d(net, 32, 5, activation='relu')
net = max_pool_2d(net, 2)
net = conv_2d(net, 64, 5, activation='relu')
net = max_pool_2d(net, 2)
net = fully_connected(net, 500, activation='relu')
net = fully_connected(net, 10, activation='softmax')
# 定义学习任务。指定优化器为sgd,学习率为0.01,损失函数为交叉熵。
net = regression(net, optimizer='sgd', learning_rate=0.01,
                 loss='categorical_crossentropy')
# 通过定义的网络结构训练模型,并在指定的验证数据上验证模型的效果。
model = tflearn.DNN(net, tensorboard_verbose=0)
model.fit(trainX, trainY, n_epoch=10,
          validation_set=([testX, testY]),
          show_metric=True)

使用 TFLeam 训练神经网络的流程:先定义神经网络的结构 ,再使用训 练数据来训练、模型。与原生态 TensorFlow 不同的地方在于, TFLeam不仅使神经网络结构定义更加简洁,还将模型训练的过程也进行了封装。另外,在定义神经网络的前向传播过程之后 , TFLeam 可以通过 regression 函数来指定损失函数和优化方法 。更方便的是,不仅 TFLeam 能很好地封装模型定义, tfleam.DNN 也能很好地封装模型训练的过程 。 通过fit函数可以指定训练中使用的数据和训练的轮数 。

Tensorflow高层封装中的keras和Estimator可以参照下列链接:

Tensorflow高层封装Keras:https://blog.csdn.net/qq_36289191/article/details/83830878

Tensorflow高层封装Estimator-DNNClassifier:https://blog.csdn.net/qq_36289191/article/details/83828873

Tensorflow高层封装Estimator-自定义模型:https://blog.csdn.net/qq_36289191/article/details/83829446

Tensorflow高层封装Estimaor用数据集作输入:https://blog.csdn.net/qq_36289191/article/details/83829980

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值