DeepLearning tutorial(2)机器学习算法在训练过程中保存参数

本文介绍了在深度学习训练过程中如何保存和加载模型参数,以防止因意外中断而丢失进度。通过Python的cPickle模块,可以方便地进行参数的序列化和反序列化。在训练循环中,当模型性能提升时,可以保存模型参数。同时,通过修改逻辑回归的代码,展示了如何在模型初始化时加载保存的参数,从而避免从头开始训练。
摘要由CSDN通过智能技术生成

DeepLearning tutorial(2)机器学习算法在训练过程中保存参数


@author:wepon

@blog:http://blog.csdn.net/u012162613/article/details/43169019


参考:pickle — Python object serializationDeepLearning Getting started


一、python读取"***.pkl.gz"文件


用到python里的gzip以及cPickle模块,简单的使用代码如下,如果想详细了解可以参考上面给出的链接。


   
   
  1. #以读取mnist.pkl.gz为例
  2. import cPickle, gzip
  3. f = gzip.open( 'mnist.pkl.gz', 'rb')
  4. train_set, valid_set, test_set = cPickle.load(f)
  5. f.close()

其实就是分两步,先读取gz文件,再读取pkl文件。pkl文件的应用正是下文要讲的,我们用它来保存机器学习算法训练过程中的参数。


二、机器学习算法在训练过程中如何保存参数?


我们知道,机器学习算法的计算量特别大,跑起程序来少则几十分钟,多则几小时甚至几天,中间如果有什么状况(比如电脑过热重启、程序出现一些小bug...)程序就会中断,如果你没把参数定时保存下来,前面的训练就当白费了,所以很有必要在程序中加入定时保存参数的功能,这样下次训练就可以将参数初始化为上次保存下来的结果,而不是从头开始随机初始化。

那么如何保存模型参数?可以将参数深复制,或者调用python的数据永久存储cPickle模块,原理不多说,直接使用就行。(注:python里有cPickle和pickle,cPickle基于c实现,比pickle快。)

直接用一个例子来说明如何使用:


    
    
  1. a=[ 1, 2, 3]
  2. b={ 4: 5, 6: 7}
  3. #保存,cPickle.dump函数。/home/wepon/ab是路径,ab是保存的文件的名字,如果/home/wepon/下本来就有ab这个文件,将被覆写#,如果没有,则创建。'wb'表示以二进制可写的方式打开。dump中的-1表示使用highest protocol。
  4. import cPickle
  5. write_file=open( '/home/wepon/ab', 'wb')
  6. cPickle.dump(a,write_file, -1)
  7. cPickle.dump(b,write_file, -1)
  8. write_file.close()
  9. #读取,cPickle.load函数。
  10. read_file=open( '/home/wepon/ab', 'rb')
  11. a_1=cPickle.load(read_file)
  12. b_1=cPickle.load(read_file)
  13. print a,b
  14. read_file.close()


在deeplearning算法中,因为用到GPU,经常是将参数声明为shared变量,因此必须用上get_value()、set_value,例如有w、v、u三个shared变量,使用代码如下:


    
    
  1. import cPickle
  2. #保存
  3. write_file = open( 'path', 'wb')
  4. cPickle.dump(w.get_value(borrow= True), write_file, -1)
  5. cPickle.dump(v.get_value(borrow= True), write_file, -1)
  6. cPickle.dump(u.get_value(borrow= True), write_file, -1)
  7. write_file.close()
  8. #读取
  9. read_file = open( 'path')
  10. w.set_value(cPickle.load(read_file), borrow= True)
  11. v.set_value(cPickle.load(read_file), borrow= True)
  12. u.set_value(cPickle.load(read_file), borrow= True)
  13. read_file.close()



一个实例

下面我以一个实际的例子来说明如何在程序中加入保存参数的功能。以deeplearnig.net上的逻辑回归为例,它的代码地址:logistic_sgd.py。这个程序是将逻辑回归用于MNIST分类,程序运行过程并不会保存参数,甚至运行结束时也不保存参数。怎么做可以保存参数?

在logistic_sgd.py代码里最后面的sgd_optimization_mnist()函数里,有个while循环,里面有一句代码:

if this_validation_loss < best_validation_loss:
    
    

这句代码的意思就是判断当前的验证损失是否小于最佳的验证损失,是的话,下面会更新best_validation_loss,也就是说当前参数下,模型比之前的有了优化,因此我们可以在这个if语句后面加入保存参数的代码:


save_params(classifier.W,classifier.b)
    
    


save_params函数定义如下:



    
    
  1. def save_params(param1,param2):
  2. import cPickle
  3. write_file = open( 'params', 'wb')
  4. cPickle.dump(param1.get_value(borrow= True), write_file, -1)
  5. cPickle.dump(param2.get_value(borrow= True), write_file, -1)
  6. write_file.close()


当然参数的个数根据需要去定义。在logistic_sgd.py中参数只有classifier.W,classifier.b,因此这里定义为save_params(param1,param2)。



在logistic_sgd.py里我加入了save_params(classifier.W,classifier.b),运行了3次epoch,中断掉程序,在代码所在的文件夹下,多出了一个params文件,我们来看看这个文件里是什么东西:

    
    
  1. import cPickle
  2. f=open( 'params')
  3. w=cPickle.load(f)
  4. b=cPickle.load(f)
  5. #w大小是(n_in,n_out),b大小时(n_out,),b的值如下,因为MINST有10个类别,n_out=10,下面正是10个数
  6. array([ -0.0888151 , 0.16875755, -0.03238435, -0.06493175, 0.05245609,
  7. 0.1754718 , -0.0155049 , 0.11216578, -0.26740651, -0.03980861])

也就是说,params文件确实保存了我们训练过程中的参数。


那么如何用保存下来的参数来初始化我们的模型的参数呢?

在logistic_sgd.py中的class LogisticRegression(object)下,self.W和self.b本来是初始化为0的,我们可以在下面加上几行代码,这样就可以用我们保存下来的params文件来初始化参数了:


    
    
  1. class LogisticRegression(object):
  2. def __init__(self, input, n_in, n_out):
  3. self.W = theano.shared(
  4. value=numpy.zeros(
  5. (n_in, n_out),
  6. dtype=theano.config.floatX
  7. ),
  8. name= 'W',
  9. borrow= True
  10. )
  11. self.b = theano.shared(
  12. value=numpy.zeros(
  13. (n_out,),
  14. dtype=theano.config.floatX
  15. ),
  16. name= 'b',
  17. borrow= True
  18. )
  19. #!!!
  20. #加入的代码在这里,程序运行到这里将会判断当前路径下有没有params文件,有的话就拿来初始化W和b
  21. if os.path.exists( 'params'):
  22. f=open( 'params')
  23. self.W.set_value(cPickle.load(f), borrow= True)
  24. self.b.set_value(cPickle.load(f), borrow= True)


  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值