TensorFlow 1.x Study Notes 11: Saving and Restoring Models


import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
(Importing TensorFlow here also prints several repeated numpy FutureWarning messages from dtypes.py about the deprecated (type, 1) dtype syntax; they are harmless and condensed away.)
mnist = input_data.read_data_sets("MNIST", one_hot=True)
Extracting MNIST\train-images-idx3-ubyte.gz
Extracting MNIST\train-labels-idx1-ubyte.gz
Extracting MNIST\t10k-images-idx3-ubyte.gz
Extracting MNIST\t10k-labels-idx1-ubyte.gz
Set the batch_size parameter and compute the number of batches per epoch
batch_size = 100
n_batches = mnist.train.num_examples // batch_size
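With the standard TF-tutorial MNIST split of 55,000 training images and batch_size = 100, this works out to n_batches = 55000 // 100 = 550 batches per epoch.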
Build the network
x = tf.placeholder(tf.float32, [None, 784])                  # flattened 28x28 input images
y = tf.placeholder(tf.float32, [None, 10])                   # one-hot labels
w = tf.Variable(tf.truncated_normal([784, 10], stddev=0.1))  # small random init
b = tf.Variable(tf.zeros([10]) + 0.1)
Predicted output
prediction = tf.nn.softmax(tf.matmul(x, w) + b)   # already softmax-normalized (see the loss note below)
Define the loss function
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=prediction))
WARNING:tensorflow:From <ipython-input-10-7d3e41057cb4>:1: softmax_cross_entropy_with_logits (from tensorflow.python.ops.nn_ops) is deprecated and will be removed in a future version.
Instructions for updating:

Future major versions of TensorFlow will allow gradients to flow
into the labels input on backprop by default.

See tf.nn.softmax_cross_entropy_with_logits_v2.

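Beyond the deprecation warning, there is a subtler issue worth flagging: prediction has already been passed through tf.nn.softmax, so feeding it to softmax_cross_entropy_with_logits applies softmax twice. Training still works, but the gradients are weakened and the loss plateaus around 1.6 rather than approaching zero, which the training log below reflects. A minimal sketch of the conventional formulation (not the code that produced the outputs in this post): keep the raw logits separate and use the non-deprecated v2 op.

logits = tf.matmul(x, w) + b                      # raw, un-normalized scores
prediction = tf.nn.softmax(logits)                # probabilities, for accuracy/readout only
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(labels=y, logits=logits))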
Define the optimizer
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
Compute the accuracy
correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
Define the variable initialization op
init = tf.global_variables_initializer()
Create the Saver object
saver = tf.train.Saver()
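tf.train.Saver accepts options that matter once checkpoints accumulate. A small sketch of two common ones, not used in the run below: max_to_keep bounds how many recent checkpoints are retained (the default is 5), and passing global_step to save() stamps the step number into the filename.

saver = tf.train.Saver(max_to_keep=3)             # keep only the 3 newest checkpoints
# inside the training loop:
#   saver.save(sess, 'saved_model/mymodel', global_step=epoch)
# -> writes saved_model/mymodel-0, saved_model/mymodel-1, ...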
Train the model and save a checkpoint each epoch
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(100):
        for batch in range(n_batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run([train_step], {x: batch_xs, y: batch_ys})
        # save to the same path each epoch, overwriting the previous checkpoint
        saver.save(sess, save_path='saved_model/mymodel')
        acc, loss = sess.run([accuracy, cross_entropy], {x: mnist.test.images, y: mnist.test.labels})
        print("Iter: " + str(epoch) + " Loss: " + str(loss) + ", Testing Acc: " + str(acc))
Iter: 0 Loss: 1.9792926, Testing Acc: 0.656
Iter: 1 Loss: 1.8310652, Testing Acc: 0.7474
Iter: 2 Loss: 1.7690496, Testing Acc: 0.7743
Iter: 3 Loss: 1.7354207, Testing Acc: 0.7932
Iter: 4 Loss: 1.7137011, Testing Acc: 0.8054
Iter: 5 Loss: 1.6985549, Testing Acc: 0.8121
Iter: 6 Loss: 1.6872387, Testing Acc: 0.818
Iter: 7 Loss: 1.6786864, Testing Acc: 0.8209
Iter: 8 Loss: 1.6717945, Testing Acc: 0.8242
Iter: 9 Loss: 1.6663136, Testing Acc: 0.826
Iter: 10 Loss: 1.6616516, Testing Acc: 0.8271
Iter: 11 Loss: 1.6578391, Testing Acc: 0.8288
Iter: 12 Loss: 1.6547573, Testing Acc: 0.8295
Iter: 13 Loss: 1.6518422, Testing Acc: 0.8304
Iter: 14 Loss: 1.649382, Testing Acc: 0.8319
Iter: 15 Loss: 1.6475127, Testing Acc: 0.8327
Iter: 16 Loss: 1.6456047, Testing Acc: 0.8332
Iter: 17 Loss: 1.6439353, Testing Acc: 0.8341
Iter: 18 Loss: 1.6423808, Testing Acc: 0.8347
Iter: 19 Loss: 1.6409477, Testing Acc: 0.8357
Iter: 20 Loss: 1.6398621, Testing Acc: 0.8362
Iter: 21 Loss: 1.6387526, Testing Acc: 0.8371
Iter: 22 Loss: 1.6377356, Testing Acc: 0.8369
Iter: 23 Loss: 1.6367621, Testing Acc: 0.8377
Iter: 24 Loss: 1.635877, Testing Acc: 0.8386
Iter: 25 Loss: 1.635138, Testing Acc: 0.8381
Iter: 26 Loss: 1.6344916, Testing Acc: 0.8384
Iter: 27 Loss: 1.6336166, Testing Acc: 0.8396
Iter: 28 Loss: 1.6330918, Testing Acc: 0.8391
Iter: 29 Loss: 1.6324666, Testing Acc: 0.8403
Iter: 30 Loss: 1.6318762, Testing Acc: 0.8403
Iter: 31 Loss: 1.6312684, Testing Acc: 0.8405
Iter: 32 Loss: 1.6307572, Testing Acc: 0.8405
Iter: 33 Loss: 1.6303653, Testing Acc: 0.8405
Iter: 34 Loss: 1.6299387, Testing Acc: 0.8409
Iter: 35 Loss: 1.6294817, Testing Acc: 0.8407
Iter: 36 Loss: 1.6290295, Testing Acc: 0.8411
Iter: 37 Loss: 1.6287543, Testing Acc: 0.8412
Iter: 38 Loss: 1.6282248, Testing Acc: 0.8419
Iter: 39 Loss: 1.6279277, Testing Acc: 0.842
Iter: 40 Loss: 1.6275755, Testing Acc: 0.8423
Iter: 41 Loss: 1.6272346, Testing Acc: 0.8422
Iter: 42 Loss: 1.6268567, Testing Acc: 0.8436
Iter: 43 Loss: 1.6266078, Testing Acc: 0.8434
Iter: 44 Loss: 1.6263326, Testing Acc: 0.8432
Iter: 45 Loss: 1.6259912, Testing Acc: 0.8438
Iter: 46 Loss: 1.6257267, Testing Acc: 0.8438
Iter: 47 Loss: 1.6255258, Testing Acc: 0.8436
Iter: 48 Loss: 1.6252373, Testing Acc: 0.8439
Iter: 49 Loss: 1.6250653, Testing Acc: 0.8437
Iter: 50 Loss: 1.6249511, Testing Acc: 0.8436
Iter: 51 Loss: 1.6245471, Testing Acc: 0.8439
Iter: 52 Loss: 1.6242584, Testing Acc: 0.844
Iter: 53 Loss: 1.624265, Testing Acc: 0.8441
Iter: 54 Loss: 1.6240141, Testing Acc: 0.8441
Iter: 55 Loss: 1.623786, Testing Acc: 0.844
Iter: 56 Loss: 1.6236019, Testing Acc: 0.8443
Iter: 57 Loss: 1.6232818, Testing Acc: 0.8445
Iter: 58 Loss: 1.6230925, Testing Acc: 0.8446
Iter: 59 Loss: 1.622992, Testing Acc: 0.8449
Iter: 60 Loss: 1.6228051, Testing Acc: 0.8451
Iter: 61 Loss: 1.6224904, Testing Acc: 0.8452
Iter: 62 Loss: 1.6224169, Testing Acc: 0.8448
Iter: 63 Loss: 1.6222577, Testing Acc: 0.8451
Iter: 64 Loss: 1.6221057, Testing Acc: 0.8452
Iter: 65 Loss: 1.6220351, Testing Acc: 0.8454
Iter: 66 Loss: 1.6218963, Testing Acc: 0.8452
Iter: 67 Loss: 1.6217381, Testing Acc: 0.8463
Iter: 68 Loss: 1.6216105, Testing Acc: 0.8462
Iter: 69 Loss: 1.6214273, Testing Acc: 0.8464
Iter: 70 Loss: 1.6213533, Testing Acc: 0.8461
Iter: 71 Loss: 1.6212226, Testing Acc: 0.8469
Iter: 72 Loss: 1.6212087, Testing Acc: 0.8465
Iter: 73 Loss: 1.6209592, Testing Acc: 0.8467
Iter: 74 Loss: 1.620869, Testing Acc: 0.8467
Iter: 75 Loss: 1.6207072, Testing Acc: 0.847
Iter: 76 Loss: 1.620664, Testing Acc: 0.847
Iter: 77 Loss: 1.6205167, Testing Acc: 0.847
Iter: 78 Loss: 1.6202962, Testing Acc: 0.8472
Iter: 79 Loss: 1.6202666, Testing Acc: 0.8474
Iter: 80 Loss: 1.6202056, Testing Acc: 0.8471
Iter: 81 Loss: 1.620063, Testing Acc: 0.8477
Iter: 82 Loss: 1.6198978, Testing Acc: 0.8476
Iter: 83 Loss: 1.6198438, Testing Acc: 0.8476
Iter: 84 Loss: 1.6197109, Testing Acc: 0.8474
Iter: 85 Loss: 1.619636, Testing Acc: 0.8477
Iter: 86 Loss: 1.6196154, Testing Acc: 0.8473
Iter: 87 Loss: 1.6194584, Testing Acc: 0.8476
Iter: 88 Loss: 1.6193762, Testing Acc: 0.8476
Iter: 89 Loss: 1.61928, Testing Acc: 0.8474
Iter: 90 Loss: 1.6191859, Testing Acc: 0.8476
Iter: 91 Loss: 1.6191095, Testing Acc: 0.8475
Iter: 92 Loss: 1.6189919, Testing Acc: 0.848
Iter: 93 Loss: 1.6190045, Testing Acc: 0.8479
Iter: 94 Loss: 1.618973, Testing Acc: 0.8477
Iter: 95 Loss: 1.6188941, Testing Acc: 0.8477
Iter: 96 Loss: 1.6188017, Testing Acc: 0.848
Iter: 97 Loss: 1.6187375, Testing Acc: 0.8482
Iter: 98 Loss: 1.6186334, Testing Acc: 0.8478
Iter: 99 Loss: 1.618518, Testing Acc: 0.8481
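Each saver.save call above writes the checkpoint as several files under saved_model/: mymodel.index and mymodel.data-00000-of-00001 hold the variable values, mymodel.meta holds the serialized graph, and a checkpoint file records which save is newest. Because the loop saves to the same path every epoch, each save overwrites the last. The newest checkpoint can also be looked up programmatically; a short sketch:

ckpt_path = tf.train.latest_checkpoint('saved_model')   # reads the 'checkpoint' bookkeeping file
print(ckpt_path)                                        # e.g. saved_model/mymodel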
Comparing the model with and without restored parameters:
  1. Model performance without restoring the parameters
  2. Model performance with the parameters fully restored
with tf.Session() as sess:
    sess.run(init)   # variables are freshly initialized; nothing is restored
    acc, loss = sess.run([accuracy, cross_entropy], {x: mnist.test.images, y: mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
 Loss: 2.310623, Testing Acc: 0.1182

with tf.Session() as sess:
    sess.run(init)   # not strictly necessary: restore() overwrites every saved variable
    saver.restore(sess, 'saved_model/mymodel')
    acc, loss = sess.run([accuracy, cross_entropy], {x: mnist.test.images, y: mnist.test.labels})
    print(" Loss: " + str(loss) + ", Testing Acc: " + str(acc))
INFO:tensorflow:Restoring parameters from saved_model/mymodel
 Loss: 1.618518, Testing Acc: 0.8481
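Both sessions above rebuild the graph in Python first; saver.restore only reloads variable values into it. If the graph-building code is unavailable, the .meta file written alongside the checkpoint can reconstruct the graph itself. A minimal sketch, with the caveat that the 'Placeholder:0' and 'Softmax:0' tensor names below are assumptions based on TF's default naming; passing explicit name= arguments when building the graph makes this lookup robust.

with tf.Session() as sess:
    # rebuild the graph structure from the .meta file, then load the weights
    new_saver = tf.train.import_meta_graph('saved_model/mymodel.meta')
    new_saver.restore(sess, tf.train.latest_checkpoint('saved_model'))
    graph = tf.get_default_graph()
    x_restored = graph.get_tensor_by_name('Placeholder:0')   # assumed name for x
    pred_restored = graph.get_tensor_by_name('Softmax:0')    # assumed name for prediction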
