tensorflow--模型的保存和提取

参考:

TensorFlow:保存和提取模型
最全Tensorflow模型保存和提取的方法——附实例

  1. 模型的保存会覆盖,后一次保存的模型会覆盖上一次保存的模型。最多保存近5次结果。
  2. 应当保存效果最优时候的模型,而不是训练最后一次的模型。所以应该在每次进行模型性能评估后与保存的目前最后效果比较,如果性能更好则进行模型的保存。
  3. 模型的复用,当你想用别的性能评估指标的时候,不需要再次训练模型来获得指标值,可以提取最优模型直接计算新指标的值。
sess=tf.InteractiveSession()  
sess.run(tf.global_variables_initializer())

is_train=False
saver=tf.train.Saver(max_to_keep=3)

#训练阶段
if is_train:
    max_acc=0
    f=open('ckpt/acc.txt','w')
    for i in range(100):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
      val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
      print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
      f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
      if val_acc>max_acc:
          max_acc=val_acc
          saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
    f.close()
 
#验证阶段
else:
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

实操:

说明:

  1. Social Attentional Memory Network 是一个推荐系统的模型,代码中没有模型保存和提取操作,数据量也算是小的,可以下载下来练习一下如何实际操作。
  2. SAMN 是我用这个模型进行的练习,可以参考,代码后面标注 lly 的是我写的或者修改的内容。

步骤:

  1. 先在原代码的主目录的下面建一个文件夹 model 。
  2. 第一次进行训练,进入目录执行 python SAMN.py ,其中参数 is_train = True
    训练完后发现model文件夹下面多了五个模型,最后一次保存的模型为最后模型,出现在第171次迭代的时候,即epoch=170
    在这里插入图片描述
    然后在控制台可以看到,epoch=170时候的评估结果:
    迭代第 166 次的损失为:26.586210:
    迭代第 167 次的损失为:26.567725:
    迭代第 168 次的损失为:26.586499:
    迭代第 169 次的损失为:26.571110:
    迭代第 170 次的损失为:26.668282:
    recall--------------------------------------------------------------------------------
    0.16846666666666665 0.19796666666666665 0.22703333333333334 0.24936666666666668 0.2713666666666667
    ndcg----------------------------------------------------------------------------------
    0.103169807535364 0.11131981364691529 0.11824016391770284 0.12317271387061263 0.12777428228959994
    save epoch  170
  1. 第二次使用保存好的模型,先将 SAMN.py 文件的参数 is_train 改为 False,再执行文件。
    执行完后可以看到控制台输出的评估结果和之前训练的时候的结果一样,证明操作成功。(最优结果我只保留了k=[10, 20, 50]的情况)
    在这里插入图片描述
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值