tensorflow卷积神经网络+mnist数据集

本博客参考:《tensorflow+keras深度学习人工智能实践应用 林大贵著》

加载数据集

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
path=r'E:\mydataset\data\LeYun_mnist'
mnist=input_data.read_data_sets(path,one_hot=True)

共享函数建立

def weight(shape):
    return tf.Variable(tf.truncated_normal(shape,stddev=0.1),name='w')
def bias(shape):
    return tf.Variable(tf.constant(0.1,shape=shape),name='b')
def conv2d(x,w):
    return tf.nn.conv2d(x,w,strides=[1,1,1,1],padding='SAME')
def maxpool_2x2(x):
    return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding='SAME' )

d

卷积网络模型建立

输入层

with tf.name_scope('input_layer'):
    x=tf.placeholder('float',shape=[None,784],name='x')
    x_image=tf.reshape(x,[-1,28,28,1])

卷积层1和池化层1

with tf.name_scope('c1_conv'):
    w1=weight([5,5,1,16])
    b1=bias([16])
    conv1=conv2d(x_image,w1)+b1
    c1_conv=tf.nn.relu(conv1)

卷积层2和池化层2

with tf.name_scope('c2_conv'):
    w2=weight([5,5,16,36])
    b2=bias([36])
    conv2=conv2d(c1_pool,w2)+b2
    c2_conv=tf.nn.relu(conv2)
with tf.name_scope('c2_pool'):
    c2_pool=maxpool_2x2(c2_conv)

展开层

with tf.name_scope('d_flat'):
    d_flat=tf.reshape(c2_pool,[-1,1764])

隐藏层(全连接层)

with tf.name_scope('d_hid_layer'):
    w3=weight([1764,128])
    b3=bias([128])
    d_hidden=tf.nn.relu(tf.matmul(d_flat,w3)+b3)
    d_hidden_dropout=tf.nn.dropout(d_hidden,keep_prob=0.8)

输出层

with tf.name_scope('output_layer'):
    w4=weight([128,10])
    b4=bias([10])
    y_predict=tf.nn.softmax(tf.matmul(d_hidden_dropout,w4)+b4)

优化器

with tf.name_scope('optimizer'):
    y_label=tf.placeholder('float',shape=[None,10],name='y_label')
    loss_fun=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y_predict,labels=y_label))
    optimizer=tf.train.AdamOptimizer(learning_rate=0.0001).minimize(loss_fun)

模型评估函数

with tf.name_scope('evaluate_model'):
    correct_pred=tf.equal(tf.argmax(y_predict,1),tf.argmax(y_label,1))
    acc=tf.reduce_mean(tf.cast(correct_pred,'float'))

模型训练参数定义

epochs=10
batch_size=100
total_batch=int(mnist.train.num_examples/batch_size)
epoch_size=[]
acc_list=[]
loss_list=[]

sess=tf.Session()
sess.run(tf.global_variables_initializer())

训练开始

for epoch in range(epochs):
    print(epoch)
    for i in range(total_batch):
        if (i+1)%250==0:
            print(i)
        x_train,y_train=mnist.train.next_batch(batch_size)
        sess.run(optimizer,feed_dict={x:x_train,y_label:y_train})
    loss,acc1=sess.run([loss_fun,acc],feed_dict={x:mnist.validation.images,y_label:mnist.validation.labels})
    epoch_size.append(epoch)
    loss_list.append(loss)
    acc_list.append(acc1)
    print(epoch+1,loss,acc1)

运行结果
在这里插入图片描述

模型训练结果可视化

import matplotlib.pyplot as plt

fig=plt.gcf()
plt.plot(range(1,11),acc_list,label='accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='upper left')

运行结果
在这里插入图片描述

fig=plt.gcf()
plt.plot(range(1,11),loss_list,label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')

运行结果

fig=plt.gcf()
plt.plot(range(1,11),loss_list,label='loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(loc='upper left')

在这里插入图片描述

模型评分

score=sess.run(acc,feed_dict={x:mnist.test.images,y_label:mnist.test.labels})
score
>>> 0.9751

利用模型进行预测

import numpy as np
pred=sess.run(tf.argmax(y_predict,1),feed_dict={x:mnist.test.images})
pred[:10]
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=int64)
real_label=np.argmax(mnist.test.labels,axis=-1)
real_label[:10]
>>> array([7, 2, 1, 0, 4, 1, 4, 9, 5, 9], dtype=int64)

显示混淆矩阵

import pandas as pd
real_label=real_label.reshape(-1)
pd.crosstab(real_label,pred,rownames=['label'],colnames=['predict'])

运行结果
在这里插入图片描述

模型保存

saver=tf.train.Saver()
save_path1=r'models/conv_mnist1.h5'
saver.save(sess, save_path=save_path1)
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

夺笋123

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值