tensorflow实例(10)--模型保存与读取,鸢尾花神经网络模型读写

TensorFlow通过tf.train.Saver类实现模型的保存和提取。
这样做的好处就如我们在用神经网络训练一个模型,需要非常长的时间,当把这个模型保存后,
当要测试或使用这个模型时,把保存的模型提取出来,不需要再次的进行训练
关于神经网络的基本原理与tensorflow的实现可参考
机器学习(1)--神经网络初探
 TensorFlow实例(4)--MNIST简介及手写数字分类算法

本例分两段代码,分别放在两个py文件中,一个用于训练与保存模型,一个用于测试模型
使用数据集简介:以鸢尾花的特征作为数据,共有数据集包含150个数据集,
分为3类setosa(山鸢尾), versicolor(变色鸢尾), virginica(维吉尼亚鸢尾)
每类50个数据,每条数据包含4个属性数据 和 一个类别数据.
会把这150个数据集分成为100个训练数据与50个测试数据,分别放在两个文件中,
当训练PYTHON程序运行后,会在指定目录下生成一个模型文件,而测试PYTHON程序只通过调用模型文件进行测试
因为采用的是神经网络的计算方式,所以对setosa(山鸢尾), versicolor(变色鸢尾), virginica(维吉尼亚鸢尾)的数据类型定义
setosa(山鸢尾)[1,0,0],versicolor(变色鸢尾)[0,1,0],virginica(维吉尼亚鸢尾)[0,0,1]

其它应用鸢尾花数据的机器学习实例可参考
机器学习(6)--朴素贝叶斯模型算法之鸢尾花数据实验
机器学习(3.2)--PCA降维鸢尾花数据降维演示

机器学习(2)--邻近算法(KNN)

代码段一:

#-*- coding:utf-8 -*-
import numpy as np
import tensorflow as tf 
train = [['5.1', '3.5', '1.4', '0.2', 'Iris-setosa'], ['4.9', '3.0', '1.4', '0.2', 'Iris-setosa'], ['4.7', '3.2', '1.3', '0.2', 'Iris-setosa'], ['5.0', '3.6', '1.4', '0.2', 'Iris-setosa'], ['5.4', '3.9', '1.7', '0.4', 'Iris-setosa'], ['4.6', '3.4', '1.4', '0.3', 'Iris-setosa'], ['4.9', '3.1', '1.5', '0.1', 'Iris-setosa'], ['4.8', '3.0', '1.4', '0.1', 'Iris-setosa'], ['4.3', '3.0', '1.1', '0.1', 'Iris-setosa'], ['5.8', '4.0', '1.2', '0.2', 'Iris-setosa'], ['5.1', '3.5', '1.4', '0.3', 'Iris-setosa'], ['5.7', '3.8', '1.7', '0.3', 'Iris-setosa'], ['5.1', '3.8', '1.5', '0.3', 'Iris-setosa'], ['4.6', '3.6', '1.0', '0.2', 'Iris-setosa'], ['5.1', '3.3', '1.7', '0.5', 'Iris-setosa'], ['5.2', '3.5', '1.5', '0.2', 'Iris-setosa'], ['5.2', '3.4', '1.4', '0.2', 'Iris-setosa'], ['4.7', '3.2', '1.6', '0.2', 'Iris-setosa'], ['4.8', '3.1', '1.6', '0.2', 'Iris-setosa'], ['5.4', '3.4', '1.5', '0.4', 'Iris-setosa'], ['5.2', '4.1', '1.5', '0.1', 'Iris-setosa'], ['5.0', '3.2', '1.2', '0.2', 'Iris-setosa'], ['5.5', '3.5', '1.3', '0.2', 'Iris-setosa'], ['4.9', '3.1', '1.5', '0.1', 'Iris-setosa'], ['4.4', '3.0', '1.3', '0.2', 'Iris-setosa'], ['5.1', '3.4', '1.5', '0.2', 'Iris-setosa'], ['5.0', '3.5', '1.3', '0.3', 'Iris-setosa'], ['4.4', '3.2', '1.3', '0.2', 'Iris-setosa'], ['5.0', '3.5', '1.6', '0.6', 'Iris-setosa'], ['4.8', '3.0', '1.4', '0.3', 'Iris-setosa'], ['5.1', '3.8', '1.6', '0.2', 'Iris-setosa'], ['4.6', '3.2', '1.4', '0.2', 'Iris-setosa'], ['5.3', '3.7', '1.5', '0.2', 'Iris-setosa'], ['5.0', '3.3', '1.4', '0.2', 'Iris-setosa'], ['7.0', '3.2', '4.7', '1.4', 'Iris-versicolor'], ['6.4', '3.2', '4.5', '1.5', 'Iris-versicolor'], ['6.9', '3.1', '4.9', '1.5', 'Iris-versicolor'], ['5.5', '2.3', '4.0', '1.3', 'Iris-versicolor'], ['6.5', '2.8', '4.6', '1.5', 'Iris-versicolor'], ['5.7', '2.8', '4.5', '1.3', 'Iris-versicolor'], ['6.3', '3.3', '4.7', '1.6', 'Iris-versicolor'], ['4.9', '2.4', '3.3', '1.0', 'Iris-versicolor'], ['6.0', '2.2', '4.0', '1.0', 'Iris-versicolor'], ['6.1', '2.9', '4.7', '1.4', 'Iris-versicolor'], ['5.6', '2.9', '3.6', '1.3', 'Iris-versicolor'], ['6.7', '3.1', '4.4', '1.4', 'Iris-versicolor'], ['5.6', '3.0', '4.5', '1.5', 'Iris-versicolor'], ['5.9', '3.2', '4.8', '1.8', 'Iris-versicolor'], ['6.1', '2.8', '4.0', '1.3', 'Iris-versicolor'], ['6.3', '2.5', '4.9', '1.5', 'Iris-versicolor'], ['6.4', '2.9', '4.3', '1.3', 'Iris-versicolor'], ['6.6', '3.0', '4.4', '1.4', 'Iris-versicolor'], ['6.8', '2.8', '4.8', '1.4', 'Iris-versicolor'], ['6.7', '3.0', '5.0', '1.7', 'Iris-versicolor'], ['5.7', '2.6', '3.5', '1.0', 'Iris-versicolor'], ['5.5', '2.4', '3.7', '1.0', 'Iris-versicolor'], ['5.8', '2.7', '3.9', '1.2', 'Iris-versicolor'], ['5.4', '3.0', '4.5', '1.5', 'Iris-versicolor'], ['6.7', '3.1', '4.7', '1.5', 'Iris-versicolor'], ['6.3', '2.3', '4.4', '1.3', 'Iris-versicolor'], ['5.6', '3.0', '4.1', '1.3', 'Iris-versicolor'], ['5.5', '2.5', '4.0', '1.3', 'Iris-versicolor'], ['5.5', '2.6', '4.4', '1.2', 'Iris-versicolor'], ['6.1', '3.0', '4.6', '1.4', 'Iris-versicolor'], ['5.8', '2.6', '4.0', '1.2', 'Iris-versicolor'], ['5.0', '2.3', '3.3', '1.0', 'Iris-versicolor'], ['5.6', '2.7', '4.2', '1.3', 'Iris-versicolor'], ['5.7', '3.0', '4.2', '1.2', 'Iris-versicolor'], ['6.2', '2.9', '4.3', '1.3', 'Iris-versicolor'], ['5.1', '2.5', '3.0', '1.1', 'Iris-versicolor'], ['6.3', '3.3', '6.0', '2.5', 'Iris-virginica'], ['7.1', '3.0', '5.9', '2.1', 'Iris-virginica'], ['6.3', '2.9', '5.6', '1.8', 'Iris-virginica'], ['6.5', '3.0', '5.8', '2.2', 'Iris-virginica'], ['7.6', '3.0', '6.6', '2.1', 'Iris-virginica'], ['4.9', '2.5', '4.5', '1.7', 'Iris-virginica'], ['7.3', '2.9', '6.3', '1.8', 'Iris-virginica'], ['6.7', '2.5', '5.8', '1.8', 'Iris-virginica'], ['7.2', '3.6', '6.1', '2.5', 'Iris-virginica'], ['6.8', '3.0', '5.5', '2.1', 'Iris-virginica'], ['5.8', '2.8', '5.1', '2.4', 'Iris-virginica'], ['7.7', '3.8', '6.7', '2.2', 'Iris-virginica'], ['7.7', '2.6', '6.9', '2.3', 'Iris-virginica'], ['6.0', '2.2', '5.0', '1.5', 'Iris-virginica'], ['6.3', '2.7', '4.9', '1.8', 'Iris-virginica'], ['7.2', '3.2', '6.0', '1.8', 'Iris-virginica'], ['6.4', '2.8', '5.6', '2.1', 'Iris-virginica'], ['7.4', '2.8', '6.1', '1.9', 'Iris-virginica'], ['6.4', '2.8', '5.6', '2.2', 'Iris-virginica'], ['6.3', '2.8', '5.1', '1.5', 'Iris-virginica'], ['6.1', '2.6', '5.6', '1.4', 'Iris-virginica'], ['6.3', '3.4', '5.6', '2.4', 'Iris-virginica'], ['6.4', '3.1', '5.5', '1.8', 'Iris-virginica'], ['6.0', '3.0', '4.8', '1.8', 'Iris-virginica'], ['6.9', '3.1', '5.4', '2.1', 'Iris-virginica'], ['5.8', '2.7', '5.1', '1.9', 'Iris-virginica'], ['6.7', '3.3', '5.7', '2.5', 'Iris-virginica'], ['6.3', '2.5', '5.0', '1.9', 'Iris-virginica'], ['6.2', '3.4', '5.4', '2.3', 'Iris-virginica'], ['5.9', '3.0', '5.1', '1.8', 'Iris-virginica']]
data=train

'''创建变量(一直到#---------------------这前的部份),这个函数在训练程序与测练程序使用完全相同
    这里包含两大部分,共七个变量
    1、将train变为x_data,y_data两部份,如原始数据中每条格式为['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']
        1.1 x_data为['5.1', '3.5', '1.4', '0.2']组成的四个特征数据
        1.2 y_data为类型数据,setosa(山鸢尾)[1,0,0],versicolor(变色鸢尾)[0,1,0],virginica(维吉尼亚鸢尾)[0,0,1]
    2、tensorflow所需要的张量
        2.1 w,b神经网络所需要的weight与biase
        2.2 x_pl,y_pl tensorflow 计算过程的传入量
        2.3 y 输入量
'''

x_data = np.array([x[0:-1] for x in data]).astype(np.float32)
y_data= np.array([[1,0,0] if x[-1] == 'Iris-setosa' else  [0,1,0] if x[-1] == 'Iris-versicolor' else [0,0,1] for x in data]).astype(np.float32)

x_pl=tf.placeholder(tf.float32,[None,x_data.shape[1]],name='x_pl')
y_pl=tf.placeholder(tf.float32,[None,y_data.shape[1]],name='y_pl')

w=tf.Variable(tf.random_uniform([x_data.shape[1],y_data.shape[1]],0,1),'w')
b=tf.Variable(tf.random_uniform([y_data.shape[1]],0,1),'b')

y=tf.nn.softmax(tf.matmul(x_pl,w)+b)
#---------------------

#训练所需要的模型,如果是在测试文件,则不需要这些了,需要的统计正确率的模型
loss=-tf.reduce_mean(y_pl*tf.log(y))
train=tf.train.GradientDescentOptimizer(0.005).minimize(loss)

sess=tf.Session()
sess.run(tf.global_variables_initializer())
for i in range(10000):
    sess.run(train,feed_dict={x_pl:x_data,y_pl:y_data})
    if i %1000 ==0 :print(i,sess.run(loss,feed_dict={x_pl:x_data,y_pl:y_data}))

saver=tf.train.Saver()
try:
    saver.save(sess,r'd:\tensorflowCkpt\Iris.ckpt') #文件名根据自己的实际情况进行修改
    sess.close()
    print('模型保存成功')
except :
    print('保存文件失败,请设置好路径,确只路径存在')
    sess.close()

代码段2

#-*- coding:utf-8 -*-  
import numpy as np
import tensorflow as tf 
test=[['4.6', '3.1', '1.5', '0.2', 'Iris-setosa'], ['5.0', '3.4', '1.5', '0.2', 'Iris-setosa'], ['4.4', '2.9', '1.4', '0.2', 'Iris-setosa'], ['5.4', '3.7', '1.5', '0.2', 'Iris-setosa'], ['4.8', '3.4', '1.6', '0.2', 'Iris-setosa'], ['5.7', '4.4', '1.5', '0.4', 'Iris-setosa'], ['5.4', '3.9', '1.3', '0.4', 'Iris-setosa'], ['5.4', '3.4', '1.7', '0.2', 'Iris-setosa'], ['5.1', '3.7', '1.5', '0.4', 'Iris-setosa'], ['4.8', '3.4', '1.9', '0.2', 'Iris-setosa'], ['5.0', '3.0', '1.6', '0.2', 'Iris-setosa'], ['5.0', '3.4', '1.6', '0.4', 'Iris-setosa'], ['5.5', '4.2', '1.4', '0.2', 'Iris-setosa'], ['4.9', '3.1', '1.5', '0.1', 'Iris-setosa'], ['4.5', '2.3', '1.3', '0.3', 'Iris-setosa'], ['5.1', '3.8', '1.9', '0.4', 'Iris-setosa'], ['6.6', '2.9', '4.6', '1.3', 'Iris-versicolor'], ['5.2', '2.7', '3.9', '1.4', 'Iris-versicolor'], ['5.0', '2.0', '3.5', '1.0', 'Iris-versicolor'], ['5.9', '3.0', '4.2', '1.5', 'Iris-versicolor'], ['5.8', '2.7', '4.1', '1.0', 'Iris-versicolor'], ['6.2', '2.2', '4.5', '1.5', 'Iris-versicolor'], ['5.6', '2.5', '3.9', '1.1', 'Iris-versicolor'], ['6.1', '2.8', '4.7', '1.2', 'Iris-versicolor'], ['6.0', '2.9', '4.5', '1.5', 'Iris-versicolor'], ['5.5', '2.4', '3.8', '1.1', 'Iris-versicolor'], ['6.0', '2.7', '5.1', '1.6', 'Iris-versicolor'], ['6.0', '3.4', '4.5', '1.6', 'Iris-versicolor'], ['5.7', '2.9', '4.2', '1.3', 'Iris-versicolor'], ['5.7', '2.8', '4.1', '1.3', 'Iris-versicolor'], ['5.8', '2.7', '5.1', '1.9', 'Iris-virginica'], ['6.5', '3.2', '5.1', '2.0', 'Iris-virginica'], ['6.4', '2.7', '5.3', '1.9', 'Iris-virginica'], ['5.7', '2.5', '5.0', '2.0', 'Iris-virginica'], ['6.4', '3.2', '5.3', '2.3', 'Iris-virginica'], ['6.5', '3.0', '5.5', '1.8', 'Iris-virginica'], ['6.9', '3.2', '5.7', '2.3', 'Iris-virginica'], ['5.6', '2.8', '4.9', '2.0', 'Iris-virginica'], ['7.7', '2.8', '6.7', '2.0', 'Iris-virginica'], ['6.7', '3.3', '5.7', '2.1', 'Iris-virginica'], ['6.2', '2.8', '4.8', '1.8', 'Iris-virginica'], ['6.1', '3.0', '4.9', '1.8', 'Iris-virginica'], ['7.2', '3.0', '5.8', '1.6', 'Iris-virginica'], ['7.9', '3.8', '6.4', '2.0', 'Iris-virginica'], ['7.7', '3.0', '6.1', '2.3', 'Iris-virginica'], ['6.7', '3.1', '5.6', '2.4', 'Iris-virginica'], ['6.9', '3.1', '5.1', '2.3', 'Iris-virginica'], ['6.8', '3.2', '5.9', '2.3', 'Iris-virginica'], ['6.7', '3.0', '5.2', '2.3', 'Iris-virginica'], ['6.5', '3.0', '5.2', '2.0', 'Iris-virginica']]
data=test


'''创建变量(一直到#---------------------这前的部份),这个函数在训练程序与测练程序使用完全相同
    这里包含两大部分,共七个变量
    1、将train变为x_data,y_data两部份,如原始数据中每条格式为['5.1', '3.5', '1.4', '0.2', 'Iris-setosa']
        1.1 x_data为['5.1', '3.5', '1.4', '0.2']组成的四个特征数据
        1.2 y_data为类型数据,setosa(山鸢尾)[1,0,0],versicolor(变色鸢尾)[0,1,0],virginica(维吉尼亚鸢尾)[0,0,1]
    2、tensorflow所需要的张量
        2.1 w,b神经网络所需要的weight与biase
        2.2 x_pl,y_pl tensorflow 计算过程的传入量
        2.3 y 输入量
'''
x_data = np.array([x[0:-1] for x in data]).astype(np.float32)
y_data= np.array([[1,0,0] if x[-1] == 'Iris-setosa' else  [0,1,0] if x[-1] == 'Iris-versicolor' else [0,0,1] for x in data]).astype(np.float32)

x_pl=tf.placeholder(tf.float32,[None,x_data.shape[1]],name='x_pl')
y_pl=tf.placeholder(tf.float32,[None,y_data.shape[1]],name='y_pl')

w=tf.Variable(tf.random_uniform([x_data.shape[1],y_data.shape[1]],0,1),'w')
b=tf.Variable(tf.random_uniform([y_data.shape[1]],0,1),'b')

y=tf.nn.softmax(tf.matmul(x_pl,w)+b)
#---------------------

sess=tf.Session()
sess.run(tf.global_variables_initializer())
isRestore = True
saver=tf.train.Saver()
try:
    saver.restore(sess,r'd:\tensorflowCkpt\Iris.ckpt') #文件名根据自己的实际情况进行修改
except :
    print('调用文件失败,请确保模型文件存在!')
    isRestore=False
if isRestore:
	print("正确率为 : "+str(round((sess.run(tf.reduce_mean(tf.cast(tf.equal(tf.argmax(y_data,axis=1),tf.argmax(sess.run(y,feed_dict={x_pl:x_data}),axis=1)),tf.float32))))*100,2)) + "%")
sess.close()

  • 3
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值