tensorflow1.x学习之 6-简单分类问题

原链接

前言

深度学习基本上就是处理两大类问题,分类问题与回归问题。系列文章的第4,5篇均是针对回归问题进行介绍,本文则会通过简单的示例,也是经典的tutorial——手写数字识别的demo来介绍分类任务的模型是如何搭建的。

知识点

mnist数据集,是tensorflow中自带的教学数据集,数据操作已经写好了,但是数据需要下载。其中每条数据是28×28的灰度图。图片上显示的各种数字(0-9)的手写体。
利用这个数据集可以做图片的10分类将达到识别数字手写体。

tf.nn.softmax()是对一个向量做softmax的操作。何为softmax操作即利用如下的计算式对列表中的每一个元素进行计算。

e i ∑ i = 0 n e i \frac{e^i}{\sum_{i=0}^{n}e^i} i=0neiei

上式中的i代表列表中的具体的数值。

tf.nn.softmax_cross_entropy_with_logits()这个函数可以求解预测值与真实值的交叉熵,其中参数logits代表预测的值,labels代表真实值。


注意计算loss的时候需要将tf.nn.softmax_cross_entropy_with_logits()的结果经过tf.reduce_mean()才能得到最终的平均loss,否则会得到一个列表,包含了一个批次中每一个训练样本的loss值。

tf.argmax()函数的作用是,将张量沿着某一个轴(维度)进行取最大值下标的操作。

tf.argmax([[0.1, 0.3, 0.6],[0.2, 0.5, 0.3]], 1)
>>> [2, 1]

tf.equal(a, b)函数用于判断张量a与张量b中哪些元素相等,最终生成一个与a,b相等的张量。a与b相等的位置为True,其余位置为False。

通过tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))可以得到预测正确率的矩阵。

tf.cast()可以将传入的张量类型进行转换(类似于强制类型转换)。

accuarcy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# tf.cast()把correct_prediction中为True的转为1.0,False转为0.0
# 利用 tf.reduce_mean()求平均后就是整个训练集的正确率

简单分类问题

利用mnist数据集实现多选一的分类问题

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
C:\Users\chengyuanting\.conda\envs\tensorflow19\lib\importlib\_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192
  return f(*args, **kwds)
C:\Users\chengyuanting\.conda\envs\tensorflow19\lib\importlib\_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 192 from C header, got 216 from PyObject
  return f(*args, **kwds)
C:\Users\chengyuanting\.conda\envs\tensorflow19\lib\importlib\_bootstrap.py:219: RuntimeWarning: numpy.ufunc size changed, may indicate binary incompatibility. Expected 216, got 192
  return f(*args, **kwds)
载入数据集
mnist = input_data.read_data_sets("MNIST",one_hot=True)
Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting MNIST\train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting MNIST\train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting MNIST\t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting MNIST\t10k-labels-idx1-ubyte.gz
设置batch_size的大小
batch_size = 50
n_batchs = mnist.train.num_examples // batch_size
n_batchs
1100
定义两个placeholder作为数据的入口
x = tf.placeholder(tf.float32,[None, 784],name="x-input")
y = tf.placeholder(tf.float32,[None, 10],name="y-input")
x,y
(<tf.Tensor 'x-input:0' shape=(?, 784) dtype=float32>,
 <tf.Tensor 'y-input:0' shape=(?, 10) dtype=float32>)
创建隐藏层网络
w = tf.Variable(tf.zeros([784,10]))
w
<tf.Variable 'Variable:0' shape=(784, 10) dtype=float32_ref>
b = tf.Variable(tf.zeros([1,10]))
b
<tf.Variable 'Variable_1:0' shape=(1, 10) dtype=float32_ref>
prediction = tf.nn.softmax(tf.matmul(x,w) + b)
prediction
<tf.Tensor 'Softmax:0' shape=(?, 10) dtype=float32>
创建交叉熵代价函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=prediction, labels = y))
WARNING:tensorflow:From <ipython-input-13-ca5e4ad4b636>:1: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

定义优化器
train_step = tf.train.GradientDescentOptimizer(0.01).minimize(loss)
train_step
<tf.Operation 'GradientDescent' type=NoOp>
初始化全局变量
init = tf.global_variables_initializer()
计算准确率
correct_prediction = tf.equal(tf.argmax(prediction, 1),tf.argmax(y,1))
accuarcy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
correct_prediction,accuarcy
(<tf.Tensor 'Equal:0' shape=(?,) dtype=bool>,
 <tf.Tensor 'Mean_1:0' shape=() dtype=float32>)
训练
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.333)
sess = tf.Session(config = tf.ConfigProto(gpu_options = gpu_options))
sess.run(init)
for epoch in range(200):
    for batch in range(n_batchs):
        batch_x,batch_y = mnist.train.next_batch(batch_size)
        sess.run([train_step],{x:batch_x,y:batch_y})
    acc = sess.run(accuarcy,feed_dict = {x:mnist.test.images,y:mnist.test.labels})
    print("Iter:",epoch,"acc:",acc)
Iter: 0 acc: 0.6188
Iter: 1 acc: 0.7683
Iter: 2 acc: 0.7967
Iter: 3 acc: 0.806
Iter: 4 acc: 0.8116
Iter: 5 acc: 0.816
Iter: 6 acc: 0.8195
Iter: 7 acc: 0.8212
Iter: 8 acc: 0.8234
Iter: 9 acc: 0.8255
Iter: 10 acc: 0.8264
Iter: 11 acc: 0.828
Iter: 12 acc: 0.83
Iter: 13 acc: 0.8428
Iter: 14 acc: 0.8611
Iter: 15 acc: 0.8705
Iter: 16 acc: 0.8763
Iter: 17 acc: 0.8827
Iter: 18 acc: 0.8881
Iter: 19 acc: 0.8909
Iter: 20 acc: 0.8935
Iter: 21 acc: 0.8952
Iter: 22 acc: 0.8965
Iter: 23 acc: 0.8971
Iter: 24 acc: 0.8985
Iter: 25 acc: 0.8982
Iter: 26 acc: 0.9
Iter: 27 acc: 0.9
Iter: 28 acc: 0.9003
Iter: 29 acc: 0.901
Iter: 30 acc: 0.902
Iter: 31 acc: 0.9022
Iter: 32 acc: 0.9028
Iter: 33 acc: 0.9035
Iter: 34 acc: 0.9038
Iter: 35 acc: 0.9042
Iter: 36 acc: 0.9046
Iter: 37 acc: 0.9053
Iter: 38 acc: 0.9055
Iter: 39 acc: 0.9055
Iter: 40 acc: 0.9057
Iter: 41 acc: 0.9066
Iter: 42 acc: 0.9063
Iter: 43 acc: 0.9067
Iter: 44 acc: 0.9076
Iter: 45 acc: 0.9078
Iter: 46 acc: 0.9084
Iter: 47 acc: 0.9083
Iter: 48 acc: 0.9088
Iter: 49 acc: 0.9093
Iter: 50 acc: 0.9095
Iter: 51 acc: 0.9093
Iter: 52 acc: 0.9093
Iter: 53 acc: 0.9094
Iter: 54 acc: 0.9092
Iter: 55 acc: 0.9091
Iter: 56 acc: 0.9092
Iter: 57 acc: 0.9097
Iter: 58 acc: 0.9091
Iter: 59 acc: 0.9093
Iter: 60 acc: 0.9098
Iter: 61 acc: 0.9099
Iter: 62 acc: 0.9103
Iter: 63 acc: 0.9105
Iter: 64 acc: 0.9106
Iter: 65 acc: 0.9106
Iter: 66 acc: 0.9109
Iter: 67 acc: 0.9113
Iter: 68 acc: 0.9115
Iter: 69 acc: 0.9114
Iter: 70 acc: 0.9114
Iter: 71 acc: 0.912
Iter: 72 acc: 0.9122
Iter: 73 acc: 0.9123
Iter: 74 acc: 0.9126
Iter: 75 acc: 0.9129
Iter: 76 acc: 0.913
Iter: 77 acc: 0.9131
Iter: 78 acc: 0.9133
Iter: 79 acc: 0.9132
Iter: 80 acc: 0.9131
Iter: 81 acc: 0.9136
Iter: 82 acc: 0.9137
Iter: 83 acc: 0.9138
Iter: 84 acc: 0.9141
Iter: 85 acc: 0.9143
Iter: 86 acc: 0.9143
Iter: 87 acc: 0.9146
Iter: 88 acc: 0.9148
Iter: 89 acc: 0.9152
Iter: 90 acc: 0.915
Iter: 91 acc: 0.9153
Iter: 92 acc: 0.9152
Iter: 93 acc: 0.9156
Iter: 94 acc: 0.9155
Iter: 95 acc: 0.9154
Iter: 96 acc: 0.9158
Iter: 97 acc: 0.9156
Iter: 98 acc: 0.9158
Iter: 99 acc: 0.9159
Iter: 100 acc: 0.9163
Iter: 101 acc: 0.9168
Iter: 102 acc: 0.9165
Iter: 103 acc: 0.9169
Iter: 104 acc: 0.917
Iter: 105 acc: 0.9169
Iter: 106 acc: 0.917
Iter: 107 acc: 0.917
Iter: 108 acc: 0.9174
Iter: 109 acc: 0.9176
Iter: 110 acc: 0.9174
Iter: 111 acc: 0.9175
Iter: 112 acc: 0.9177
Iter: 113 acc: 0.9174
Iter: 114 acc: 0.9177
Iter: 115 acc: 0.9177
Iter: 116 acc: 0.9181
Iter: 117 acc: 0.9178
Iter: 118 acc: 0.9179
Iter: 119 acc: 0.918
Iter: 120 acc: 0.9183
Iter: 121 acc: 0.9183
Iter: 122 acc: 0.9183
Iter: 123 acc: 0.9183
Iter: 124 acc: 0.9188
Iter: 125 acc: 0.9192
Iter: 126 acc: 0.9189
Iter: 127 acc: 0.9189
Iter: 128 acc: 0.9189
Iter: 129 acc: 0.9193
Iter: 130 acc: 0.9193
Iter: 131 acc: 0.9195
Iter: 132 acc: 0.9195
Iter: 133 acc: 0.9194
Iter: 134 acc: 0.919
Iter: 135 acc: 0.9196
Iter: 136 acc: 0.9194
Iter: 137 acc: 0.9194
Iter: 138 acc: 0.9192
Iter: 139 acc: 0.9194
Iter: 140 acc: 0.9194
Iter: 141 acc: 0.9195
Iter: 142 acc: 0.9196
Iter: 143 acc: 0.9196
Iter: 144 acc: 0.9201
Iter: 145 acc: 0.9194
Iter: 146 acc: 0.9198
Iter: 147 acc: 0.9198
Iter: 148 acc: 0.9197
Iter: 149 acc: 0.9199
Iter: 150 acc: 0.9198
Iter: 151 acc: 0.92
Iter: 152 acc: 0.92
Iter: 153 acc: 0.9199
Iter: 154 acc: 0.9199
Iter: 155 acc: 0.9199
Iter: 156 acc: 0.92
Iter: 157 acc: 0.9203
Iter: 158 acc: 0.9205
Iter: 159 acc: 0.9205
Iter: 160 acc: 0.9207
Iter: 161 acc: 0.9204
Iter: 162 acc: 0.9207
Iter: 163 acc: 0.9204
Iter: 164 acc: 0.9203
Iter: 165 acc: 0.9206
Iter: 166 acc: 0.9207
Iter: 167 acc: 0.9204
Iter: 168 acc: 0.9207
Iter: 169 acc: 0.9205
Iter: 170 acc: 0.9207
Iter: 171 acc: 0.9205
Iter: 172 acc: 0.9208
Iter: 173 acc: 0.921
Iter: 174 acc: 0.9209
Iter: 175 acc: 0.9208
Iter: 176 acc: 0.9208
Iter: 177 acc: 0.9209
Iter: 178 acc: 0.9208
Iter: 179 acc: 0.9211
Iter: 180 acc: 0.921
Iter: 181 acc: 0.9209
Iter: 182 acc: 0.921
Iter: 183 acc: 0.921
Iter: 184 acc: 0.921
Iter: 185 acc: 0.9211
Iter: 186 acc: 0.9211
Iter: 187 acc: 0.921
Iter: 188 acc: 0.9212
Iter: 189 acc: 0.9214
Iter: 190 acc: 0.9215
Iter: 191 acc: 0.9213
Iter: 192 acc: 0.9208
Iter: 193 acc: 0.9212
Iter: 194 acc: 0.9214
Iter: 195 acc: 0.9212
Iter: 196 acc: 0.9215
Iter: 197 acc: 0.9217
Iter: 198 acc: 0.9216
Iter: 199 acc: 0.9212

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值