kares中如何从已保存的训练模型中了解模型的具体结构

4 篇文章 0 订阅
2 篇文章 0 订阅

文章目录

1. 问题由来

1.1 问题描述

在进行机器学习工程改造时,可能存在一种情况,那就是该项目中调用的是一个已经训练好的机器学习模型,且该模型未在文件中定义。

要想改造工程,我们首先需要实现每个机器学习网络,因此掌握已有模型结构是必不可少的。

然而在目前的一些方法中,我们只能有限的观察模型的结构,大部分的方法都忽略了每一层的激活函数。

通过一番查阅资料和捣鼓,最终解决了该问题。

如果你想整体了解都有哪些方法,请耐心读完,但倘若你只是想快速解决问题,那么请直接跳转目录2.3

1.2 问题假设

为了便于描述,在这里假设模型的网络结构如下:

self.h1 = tf.keras.layers.Dense(100, activation='relu')
self.h2 = tf.keras.layers.Dense(64, activation='relu')
self.h3 = tf.keras.layers.Dense(32, activation='sigmoid')
self.h4 = tf.keras.layers.Dense(16, activation='sigmoid')
self.h5 = tf.keras.layers.Dense(8, activation='sigmoid')
self.out = tf.keras.layers.Dense(1)
self.model = tf.keras.models.Sequential([self.h1, self.h2, self.h3, self.h4, self.h5, self.out])

上面只是描述了模型的整体真实结构,便于后面问题的描述,并不是一个可运行的程序,望周知!

在假设的模型中一共有五层全连接网络,四个激活函数。在训练完网络时,我们会使用model.save()操作保存训练好的模型,调用时直接调用训练好的模型,因此当在项目中只有训练好的模型,比如model.h文件,而没有具体模型的描述,我们要想改进代码,则需要知晓原来的网络模型。

下面将会描述三种方法获取模型结构,使用前两种方法,并不能完全掌握模型的结构,因此我查询资料后补充了第三种,如果有任何不妥,欢迎指出。

2. 问题解决

2.1 模型简单描述方法model.summary()

使用模型摘要,直接调用已有的方法,这是最快捷直接的,通过该方法,我们能够得到模型的以下内容:

  1. 网络层名称及对应网络层类型
  2. 各层的输出形状
  3. 各层的参数数目

在假设的网络模型中,调用model.summary()方法的输出如下:

Model: "sequential"
____________________________________________________________________________
 Layer (type)                Output Shape              Param #   Trainable  
============================================================================
 dense (Dense)               (None, 100)               1000      Y          
                                                                            
 dense_1 (Dense)             (None, 64)                6464      Y          
                                                                            
 dense_2 (Dense)             (None, 32)                2080      Y          
                                                                            
 dense_3 (Dense)             (None, 16)                528       Y          
                                                                            
 dense_4 (Dense)             (None, 8)                 136       Y          
                                                                            
 dense_5 (Dense)             (None, 1)                 9         Y          
                                                                            
============================================================================
Total params: 10,217
Trainable params: 10,217
Non-trainable params: 0
____________________________________________________________________________

缺点:

  1. 无法掌握输入特征数(但也可以通过debug查询)
  2. 无法掌握各层对应的激活函数

2.2 通过plot_model绘制模型图像

借助plot_model方法绘制模型的结构图,具体流程如下:

  1. 使用pip安装pydot

    pip insall pydot -i https://pypi.douban.com/simple
    
  2. 安装graphviz(重点)
    graphviz包可以通过pip进行安装,但若是通过pip安装了,建议删除,否则可能会报错。该包需要通过官网下载安装文件进行安装,并配置环境变量,官网地址如下:
    graphviz官方下载地址

  3. 导入plot_model方法并调用方法绘制图像
    导入plot_model包。

    from keras.utils import plot_model
    

    调用plot_model方法进行模型图像绘制。

    plot_model(model=model, to_file='model.png', show_shapes=True)
    

最终,绘制图像如下所示:
在这里插入图片描述

优缺点

  • 优点:可以查看每层网络的输入
  • 缺点:无法了解模型中各层的激活函数

2.3 借助model.to_json文件还原模型

以上方法都没办法让我们完整的了解模型的具体结构,因此,在查询相关资料后发现,通过model.to_json获取模型描述的json文件,从该文件中可以获取模型中所需要的各类结构描述及参数。使用方法如下:

# 1.获取模型的json文件
model_json = model.to_json()
# 2.打印json内容
print(model_json)

输出如下:

{"class_name": "Sequential", "config": {"name": "sequential", "layers": [{"class_name": "InputLayer", "config": {"batch_input_shape": [null, 9], "dtype": "float64", "sparse": false, "ragged": false, "name": "dense_input"}}, {"class_name": "Dense", "config": {"name": "dense", "trainable": true, "dtype": "float32", "units": 100, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_1", "trainable": true, "dtype": "float32", "units": 64, "activation": "relu", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_2", "trainable": true, "dtype": "float32", "units": 32, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_3", "trainable": true, "dtype": "float32", "units": 16, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_4", "trainable": true, "dtype": "float32", "units": 8, "activation": "sigmoid", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}, {"class_name": "Dense", "config": {"name": "dense_5", "trainable": true, "dtype": "float32", "units": 1, "activation": "linear", "use_bias": true, "kernel_initializer": {"class_name": "GlorotUniform", "config": {"seed": null}}, "bias_initializer": {"class_name": "Zeros", "config": {}}, "kernel_regularizer": null, "bias_regularizer": null, "activity_regularizer": null, "kernel_constraint": null, "bias_constraint": null}}]}, "keras_version": "2.9.0", "backend": "tensorflow"}

从上面的json文件中我们可以查看模型的所有参数,但是我们还原模型只需要各层结构及激活函数,因此,直接查看json未免太废眼。

因此,我编写了一个函数model_josn_read,通过调用该方法直接可以像调用model.summary()一样直观展示模型的结构。调用方法如下:

model_josn_read(model_json=model_json)

输出结果模式如下:

==========================================================================================
Model Name: sequential
------------------------------------------------------------------------------------------
Layer Name(type)              Layer shape or units          Activation                    
==========================================================================================
dense_input(InputLayer)       [None, 9]                     
dense(Dense)                  100                           relu                          
dense_1(Dense)                64                            relu                          
dense_2(Dense)                32                            sigmoid                       
dense_3(Dense)                16                            sigmoid                       
dense_4(Dense)                8                             sigmoid                       
dense_5(Dense)                1                             linear                        
==========================================================================================

通过该方法,我们可以直接的了解每一层的网络结构及激活函数,同时输入时的特征也使用InputLayer输出显示出来,然后根据神经网络输入输出层规律,便可计算出来。

完整model_josn_read代码如下:

def model_josn_read(model_json):
    # 将json字符串转化为字典前的必要操作
    model_json = model_json.replace("null", "None")
    model_json = model_json.replace("false", "False")
    model_json = model_json.replace("true", "True")
    model_json = eval(model_json)

    model_json = model_json["config"]

    print("=" * 90)
    print('Model Name: %s' % model_json['name'])
    print('-' * 90)
    width = 30
    print('{: <{}}{: <{}}{: <{}}'.format('Layer Name(type)', width, 'Layer shape or units', width, 'Activation', width))
    print("=" * 90)
    model_json = model_json['layers']

    for layer in model_json:
        if layer['class_name'] == 'InputLayer':
            layer = layer['config']
            inputSize = layer['batch_input_shape']

            print('{: <{}}{: <{}}'.format(layer['name']+'(InputLayer)', width, str(inputSize), width))
        else:
            print('{: <{}}{: <{}}{: <{}}'.format(layer['config']['name']+'('+layer['class_name']+')', width,
                                                 str(layer['config']['units']), width, layer['config']['activation'], width))
    print("=" * 90)

优缺点

通过模型的json文件可以完整了解模型的结构并完成模型复现,能够较好解决当前的问题。

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值