[AI Tutorial] Getting Started with TensorFlow: Simple Linear Model

Introduction

This article demonstrates the basic TensorFlow workflow using a simple linear model.
Data set: MNIST
Tools: TensorFlow 1.9.0 + Python 3.6.3
Method: simple linear model

1、import

import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn.metrics import confusion_matrix
# Check TensorFlow's version.
tf.__version__

2、Load Data

# The MNIST data-set is about 12 MB 
# and will be downloaded automatically if it is not located in the given path.
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets("data/MNIST/", one_hot=True)

The MNIST data-set has now been loaded and consists of 70,000 images and their associated labels (i.e. the classifications of the images). The data-set is split into three mutually exclusive sub-sets. In this article we only use the training set and the test set.

print("Size of:")
print("- Training-set:\t\t{}".format(len(data.train.labels)))
print("- Test-set:\t\t{}".format(len(data.test.labels)))
print("- Validation-set:\t{}".format(len(data.validation.labels)))

The output is:
Size of:
- Training-set: 55000
- Test-set: 10000
- Validation-set: 5000

3、One-Hot Encoding

The data-set has been loaded as One-Hot encoded. This means the labels have been converted from a single number into a vector whose length equals the number of possible classes. All elements of the vector are zero except for the i'th element, which is one and means the class is i. For example, the One-Hot encoded labels for the first 5 images in the test-set are:

data.test.labels[0:5, :]

The output is:

array([[ 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.],
[ 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
[ 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])

For various comparisons and performance measures we also need the classes as single numbers, so we convert the One-Hot encoded vectors to single numbers by taking the index of the highest element. Note that the word 'class' is a keyword in Python, so we use the name 'cls' instead.

data.test.cls = np.array([label.argmax() for label in data.test.labels])
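
As a side note (not part of the original code), the same conversion can be written as a single vectorized call, assuming data.test.labels is an ordinary numpy array:

# Hypothetical alternative: argmax over axis 1 picks the index of the
# highest element in each One-Hot row.
data.test.cls = data.test.labels.argmax(axis=1)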

We can now see the classes for the first five images in the test-set. Compare these with the One-Hot encoded vectors above. For example, the class of the first image is 7, which corresponds to a One-Hot encoded vector where all elements are zero except for the element with index 7.

data.test.cls[0:5]

The output is:

array([7, 2, 1, 0, 4], dtype=int64)

4、Data dimensions

The data dimensions are used in several places in the code below. In computer programming it is generally best to use variables and constants rather than hard-coding a specific number every time it is used; that way the number only has to be changed in one place.

# We know that MNIST images are 28 pixels in each dimension.
img_size = 28

# Images are stored in one-dimensional arrays of this length.
img_size_flat = img_size * img_size

# Tuple with height and width of images used to reshape arrays.
img_shape = (img_size, img_size)

# Number of classes, one class for each of 10 digits.
num_classes = 10
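
As an optional sanity check (not part of the original tutorial), these constants can be compared against the loaded data:

# Each flattened MNIST image should have img_size_flat = 784 pixel values,
# and each One-Hot label should have num_classes = 10 elements.
assert data.test.images.shape[1] == img_size_flat
assert data.test.labels.shape[1] == num_classes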

5、Helper-function for plotting images

This helper-function plots 9 images in a 3x3 grid and writes the true and predicted classes below each image.

def plot_images(images, cls_true, cls_pred=None):
    assert len(images) == len(cls_true) == 9
    
    # Create figure with 3x3 sub-plots.
    fig, axes = plt.subplots(3, 3)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Plot image.
        ax.imshow(images[i].reshape(img_shape), cmap='binary')

        # Show true and predicted classes.
        if cls_pred is None:
            xlabel = "True: {0}".format(cls_true[i])
        else:
            xlabel = "True: {0}, Pred: {1}".format(cls_true[i], cls_pred[i])

        ax.set_xlabel(xlabel)
        
        # Remove ticks from the plot.
        ax.set_xticks([])
        ax.set_yticks([])
        
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()

Plot a few images from the test-set to check that the function above works correctly.

# Get the first images from the test-set.
images = data.test.images[0:9]

# Get the true classes for those images.
cls_true = data.test.cls[0:9]

# Plot the images and labels using our helper-function above.
plot_images(images=images, cls_true=cls_true)

The output is a 3x3 grid of the first 9 test-set images, with the true class shown below each image.

6、Placeholder variables

Placeholder variables serve as the input to the TensorFlow graph and may be changed each time we execute the graph.

# Define the placeholder variable for the input images.
# None means that the tensor may hold an arbitrary number of images 
# with each image being a vector of length img_size_flat.
x = tf.placeholder(tf.float32, [None, img_size_flat])

# Define the placeholder variable for the true labels.
y_true = tf.placeholder(tf.float32, [None, num_classes])

# Define the placeholder variable for the true class.
y_true_cls = tf.placeholder(tf.int64, [None])
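
To make the role of placeholders more concrete, here is a minimal stand-alone sketch (not part of the tutorial's graph; the demo_* names are made up) showing that a placeholder only receives its value through feed_dict when the graph is executed:

# A tiny example of feeding a placeholder at run-time.
demo_x = tf.placeholder(tf.float32, [None, 2])
demo_sum = tf.reduce_sum(demo_x)
with tf.Session() as demo_sess:
    # The value for demo_x is supplied only now, via feed_dict.
    print(demo_sess.run(demo_sum, feed_dict={demo_x: [[1.0, 2.0]]}))  # -> 3.0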

7、Variables to be optimized

# The first variable that must be optimized is called weights 
# and is defined here as a TensorFlow variable that must be initialized with zeros 
# and whose shape is [img_size_flat, num_classes], 
# so it is a 2-dimensional tensor (or matrix) 
# with img_size_flat rows and num_classes columns.
weights = tf.Variable(tf.zeros([img_size_flat, num_classes]))

# The second variable that must be optimized is called biases 
# and is defined as a 1-dimensional tensor (or vector) of length num_classes.
biases = tf.Variable(tf.zeros([num_classes]))
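
As a quick optional check (for inspection only, not required by the rest of the code), the static shapes of the two variables can be printed:

# weights is a matrix of shape (img_size_flat, num_classes) = (784, 10);
# biases is a vector of length num_classes = 10.
print(weights.get_shape())  # (784, 10)
print(biases.get_shape())   # (10,)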

8、Model

# This simple mathematical model multiplies the images 
# in the placeholder variable x with the weights and then adds the biases.
logits = tf.matmul(x, weights) + biases
y_pred = tf.nn.softmax(logits)

y_pred_cls = tf.argmax(y_pred, axis=1)
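
To see what these three lines compute, here is a small numpy-only sketch (illustration only; the _demo names and numbers are made up and not part of the TensorFlow graph). Each row of logits corresponds to one image, softmax turns it into a probability distribution over the classes, and argmax picks the most likely class:

# Hypothetical example: 2 "images" with 4 pixels each and 3 classes.
x_demo = np.array([[0.0, 1.0, 0.5, 0.2],
                   [0.3, 0.1, 0.9, 0.7]])
w_demo = np.random.randn(4, 3)
b_demo = np.zeros(3)

logits_demo = x_demo.dot(w_demo) + b_demo                    # shape (2, 3)
exp_demo = np.exp(logits_demo - logits_demo.max(axis=1, keepdims=True))
probs_demo = exp_demo / exp_demo.sum(axis=1, keepdims=True)  # each row sums to 1
print(probs_demo.argmax(axis=1))                             # predicted class per "image"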

9、Cost-function to be optimized

# The cross-entropy is computed from the logits directly so that TensorFlow
# can use a numerically stable softmax internally.
cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=logits,
                                                        labels=y_true)
# Average the cross-entropy over all images to get a single scalar cost.
cost = tf.reduce_mean(cross_entropy)
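
The cross-entropy for a single image is -sum(y_true * log(y_pred)): it is zero when the predicted distribution matches the One-Hot label exactly and grows as the prediction gets worse. A minimal numpy illustration with made-up numbers (not part of the graph):

# The One-Hot label says class 1; the model assigns it probability 0.7.
y_true_demo = np.array([0.0, 1.0, 0.0])
y_pred_demo = np.array([0.2, 0.7, 0.1])
print(-np.sum(y_true_demo * np.log(y_pred_demo)))  # ~0.357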

10、Optimization method

# Add a gradient-descent op to the graph; it is only run later via session.run().
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.5).minimize(cost)

11、Performance measures

# Boolean vector: whether the predicted class equals the true class of each image.
correct_prediction = tf.equal(y_pred_cls, y_true_cls)

# Cast the booleans to floats and take the mean to get the classification accuracy.
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
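
The accuracy is simply the fraction of correctly classified images. As a quick illustration in plain numpy (hypothetical boolean vector, not part of the graph):

correct_demo = np.array([True, False, True, True])
print(correct_demo.astype(np.float32).mean())  # 0.75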

12、TensorFlow Run

# Create a TensorFlow session to execute the graph.
session = tf.Session()

# Initialize the variables for weights and biases before optimizing them.
session.run(tf.global_variables_initializer())

batch_size = 100

def optimize(num_iterations):
    for i in range(num_iterations):
        # Get a batch of training examples.
        # x_batch now holds a batch of images and
        # y_true_batch are the true labels for those images.
        x_batch, y_true_batch = data.train.next_batch(batch_size)
        
        # Put the batch into a dict with the proper names
        # for placeholder variables in the TensorFlow graph.
        # Note that the placeholder for y_true_cls is not set
        # because it is not used during training.
        feed_dict_train = {x: x_batch,
                           y_true: y_true_batch}

        # Run the optimizer using this batch of training data.
        # TensorFlow assigns the variables in feed_dict_train
        # to the placeholder variables and then runs the optimizer.
        session.run(optimizer, feed_dict=feed_dict_train)

13、Helper-functions to show performance

feed_dict_test = {x: data.test.images,
                  y_true: data.test.labels,
                  y_true_cls: data.test.cls}

def print_accuracy():
    # Use TensorFlow to compute the accuracy.
    acc = session.run(accuracy, feed_dict=feed_dict_test)
    
    # Print the accuracy.
    print("Accuracy on test-set: {0:.1%}".format(acc))
def print_confusion_matrix():
    # Get the true classifications for the test-set.
    cls_true = data.test.cls
    
    # Get the predicted classifications for the test-set.
    cls_pred = session.run(y_pred_cls, feed_dict=feed_dict_test)

    # Get the confusion matrix using sklearn.
    cm = confusion_matrix(y_true=cls_true,
                          y_pred=cls_pred)

    # Print the confusion matrix as text.
    print(cm)

    # Plot the confusion matrix as an image.
    plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)

    # Make various adjustments to the plot.
    plt.tight_layout()
    plt.colorbar()
    tick_marks = np.arange(num_classes)
    plt.xticks(tick_marks, range(num_classes))
    plt.yticks(tick_marks, range(num_classes))
    plt.xlabel('Predicted')
    plt.ylabel('True')
    
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()

def plot_example_errors():
    # Use TensorFlow to get a list of boolean values
    # whether each test-image has been correctly classified,
    # and a list for the predicted class of each image.
    correct, cls_pred = session.run([correct_prediction, y_pred_cls],
                                    feed_dict=feed_dict_test)

    # Negate the boolean array.
    incorrect = (correct == False)
    
    # Get the images from the test-set that have been
    # incorrectly classified.
    images = data.test.images[incorrect]
    
    # Get the predicted classes for those images.
    cls_pred = cls_pred[incorrect]

    # Get the true classes for those images.
    cls_true = data.test.cls[incorrect]
    
    # Plot the first 9 images.
    plot_images(images=images[0:9],
                cls_true=cls_true[0:9],
                cls_pred=cls_pred[0:9])

def plot_weights():
    # Get the values for the weights from the TensorFlow variable.
    w = session.run(weights)
    
    # Get the lowest and highest values for the weights.
    # This is used to correct the colour intensity across
    # the images so they can be compared with each other.
    w_min = np.min(w)
    w_max = np.max(w)

    # Create figure with 3x4 sub-plots,
    # where the last 2 sub-plots are unused.
    fig, axes = plt.subplots(3, 4)
    fig.subplots_adjust(hspace=0.3, wspace=0.3)

    for i, ax in enumerate(axes.flat):
        # Only use the weights for the first 10 sub-plots.
        if i<10:
            # Get the weights for the i'th digit and reshape it.
            # Note that w.shape == (img_size_flat, 10)
            image = w[:, i].reshape(img_shape)

            # Set the label for the sub-plot.
            ax.set_xlabel("Weights: {0}".format(i))

            # Plot the image.
            ax.imshow(image, vmin=w_min, vmax=w_max, cmap='seismic')

        # Remove ticks from each sub-plot.
        ax.set_xticks([])
        ax.set_yticks([])
        
    # Ensure the plot is shown correctly with multiple plots
    # in a single Notebook cell.
    plt.show()

14、Performance before any optimization

print_accuracy()

plot_example_errors()

15、Performance after 1 optimization iteration

optimize(num_iterations=1)

print_accuracy()

plot_example_errors()

plot_weights()

16、Performance after 10 optimization iterations

# We have already performed 1 iteration.
optimize(num_iterations=9)

print_accuracy()

plot_example_errors()

plot_weights()

17、Performance after 1000 optimization iterations

# We have already performed 10 iterations.
optimize(num_iterations=990)

print_accuracy()

plot_example_errors()

plot_weights()

print_confusion_matrix()

Content edited by: 张永辉
