前言
本文是为了实现存储自己训练好的模型 结构和参数,以及加载训练好的模型进行预测。
代码
保存
def save(self,filename):
"""
模型保存
:param filename: 文件名
:return:
"""
data ={ "sizes": self.sizes, #模型结构
"weights": [w.tolist() for w in self.weights], #tolist转换为列表类型
"biases": [b.tolist() for b in self.biases],
"cost": str(self.cost.__name__) #保存一下损失函数
}
f=open(filename,"w")
json.dump(data,f)
f.close()
加载
def load(filename):
"""
加载模型
:param filename:
:return:
"""
f=open(filename,"r")
data=json.load(f)
f.close()
cost=getattr(sys.modules[__name__],data["cost"]) #找对象
net=Network(data["sizes"],cost=cost)
net.weights=[np.array(w) for w in data["weights"]]
net.biases=[np.array(b) for b in data["biases"]]
return net