TensorLayer保存和读取模型tf.files.save_npz和tf.files.load

读取

tensorlayer.files.save_npz(save_list=[], name='model.npz', sess=None)

其中save_list为所要保存的参数,name为路径和保存的文件名,传入一个sess来执行此次操作。 


保存

tensorlayer.files.load_and_assign_npz(sess=None, name=None, network=None)

其中sess和name意义与保存(save_npz)相同,但是要注意的是network应该传入一个Layer类,然而如果只是简单初始化一个Layer类的变量传进去,运行立刻就会报错。

原因在于TensorLayer(也可以说是TensorFlow)所谓的保存,只是保存模型的参数和变量的值,而不是模型本身。这一点和sklearn中的模型保存是有区别的。 
因此只有当读取时的Layer(即模型)和保存时的模型结构上一模一样,才可以将保存的模型参数一一对应,从而对模型赋值。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值