关于保存h5模型、权重网上的示例非常多,也非常简单。主要有以下两个函数:
1、keras.models.load_model() 读取网络、权重
2、keras.models.load_weights() 仅读取权重
load_model代码包含load_weights的代码,区别在于load_weights时需要先有网络、并且load_weights需要将权重数据写入到对应网络层的tensor中。
下面以resnet50加载h5权重为例,示例代码如下
import keras
from keras.preprocessing import image
import numpy as np
from network.resnet50 import resnet50
#修改过,不加载权重(默认官方加载亦可)
model = resnet50()
# 参数默认 by_name = fasle, 否则只读取匹配的权重
# 这里h5的层和权重文件中层名是对应的(除input层)
model.load_weights(r'\models\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5')
模型通过 model.summary()输出
一、模型加载权重 load_weights()
def load_weights(self, filepath, by_name=false, skip_mismatch=false, reshape=false):
if h5py is none:
raise importerror('`load_weights` requires h5py.')
with h5py.file(filepath, mode='r') as f:
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
if by_name:
saving.load_weights_from_hdf5_group_by_name(
f, self.layers, skip_mismatch=skip_mismatch,reshape=reshape)
else:
saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape)
这里关心函数saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape)即可,参数 f 传递了一个h5py文件对象。
读取h5文件使用 h5py 包,简单使用hdfview看一下resnet50的权重文件。
resnet50_v2 这个权重文件,仅一个attr “layer_names”, 该attr包含177个string的array,array中每个元素就是层的名字(这里是严格对应在keras进行保存权重时网络中每一层的name值,且层的顺序也严格对应)。
对于每一个key(层名),都有一个属性"weights_names",(value值可能为空)。
例如:
conv1的"weights_names"有"conv1_w:0"和"conv1_b:0",
flatten_1的"weights_names"为null。
这里就简单介绍,后面在代码中说明h5py如何读取权重数据。
二、从hdf5文件中加载权重 load_weights_from_hdf5_group()
1、找出keras模型层中具有weight的tensor(tf.variable)的层
def load_weights_from_hdf5_group(f, layers, reshape=false):
# keras模型resnet50的model.layers的过滤
# 仅保留layer.weights不为空的层,过滤掉无学习参数的层
filtered_layers = []
for layer in layers:
weights = layer.weights
if weights:
filtered_layers.append(layer)
filtered_layers为当前模型resnet50过滤(input、paddind、activation、merge/add、flastten等)层后剩下107层的list
2、从hdf5文件中获取包含权重数据的层的名字
前面通过hdfview看过每一层有一个[“weight_names”]属性,如果不为空,就说明该层存在权重数据。
先看一下控制台对h5py对象f的基本操作(需要的去查看相关数据结构定义):
>>> f
>>> f.filename
'e:\\deeplearning\\keras_test\\models\\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5'
>>> f.name
'/'
>>> f.attrs.keys() # f属性列表 #
>>> f.k