load()是python文件操作的函数_keras读取h5文件load_weights、load代码操作

本文详细介绍了Python Keras中加载H5模型和权重的操作,包括`load_model()`和`load_weights()`的区别及使用示例,以及如何通过`h5py`库读取权重数据。通过解析`resnet50`模型的H5权重文件,展示了从HDF5文件中加载权重的步骤,包括获取层名、检查权重数据和将权重写入模型的过程。
摘要由CSDN通过智能技术生成

关于保存h5模型、权重网上的示例非常多,也非常简单。主要有以下两个函数:

1、keras.models.load_model() 读取网络、权重

2、keras.models.load_weights() 仅读取权重

load_model代码包含load_weights的代码,区别在于load_weights时需要先有网络、并且load_weights需要将权重数据写入到对应网络层的tensor中。

下面以resnet50加载h5权重为例,示例代码如下

import keras

from keras.preprocessing import image

import numpy as np

from network.resnet50 import resnet50

#修改过,不加载权重(默认官方加载亦可)

model = resnet50()

# 参数默认 by_name = fasle, 否则只读取匹配的权重

# 这里h5的层和权重文件中层名是对应的(除input层)

model.load_weights(r'\models\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5')

模型通过 model.summary()输出

一、模型加载权重 load_weights()

def load_weights(self, filepath, by_name=false, skip_mismatch=false, reshape=false):

if h5py is none:

raise importerror('`load_weights` requires h5py.')

with h5py.file(filepath, mode='r') as f:

if 'layer_names' not in f.attrs and 'model_weights' in f:

f = f['model_weights']

if by_name:

saving.load_weights_from_hdf5_group_by_name(

f, self.layers, skip_mismatch=skip_mismatch,reshape=reshape)

else:

saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape)

这里关心函数saving.load_weights_from_hdf5_group(f, self.layers, reshape=reshape)即可,参数 f 传递了一个h5py文件对象。

读取h5文件使用 h5py 包,简单使用hdfview看一下resnet50的权重文件。

resnet50_v2 这个权重文件,仅一个attr “layer_names”, 该attr包含177个string的array,array中每个元素就是层的名字(这里是严格对应在keras进行保存权重时网络中每一层的name值,且层的顺序也严格对应)。

对于每一个key(层名),都有一个属性"weights_names",(value值可能为空)。

例如:

conv1的"weights_names"有"conv1_w:0"和"conv1_b:0",

flatten_1的"weights_names"为null。

这里就简单介绍,后面在代码中说明h5py如何读取权重数据。

二、从hdf5文件中加载权重 load_weights_from_hdf5_group()

1、找出keras模型层中具有weight的tensor(tf.variable)的层

def load_weights_from_hdf5_group(f, layers, reshape=false):

# keras模型resnet50的model.layers的过滤

# 仅保留layer.weights不为空的层,过滤掉无学习参数的层

filtered_layers = []

for layer in layers:

weights = layer.weights

if weights:

filtered_layers.append(layer)

filtered_layers为当前模型resnet50过滤(input、paddind、activation、merge/add、flastten等)层后剩下107层的list

2、从hdf5文件中获取包含权重数据的层的名字

前面通过hdfview看过每一层有一个[“weight_names”]属性,如果不为空,就说明该层存在权重数据。

先看一下控制台对h5py对象f的基本操作(需要的去查看相关数据结构定义):

>>> f

>>> f.filename

'e:\\deeplearning\\keras_test\\models\\resnet50_weights_tf_dim_ordering_tf_kernels_v2.h5'

>>> f.name

'/'

>>> f.attrs.keys() # f属性列表 #

>>> f.k

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值