Python深度学习13——Keras模型的可视化(神经网络结构图)

报错问题 

Failed to import pydot. You must `pip install pydot` and install graphviz


我们在调用keras里面的高级API——plot_model(),去画神经网络的结构图的时候可能会遇到两个报错问题。

第一个是说keras.utils里面不存在plot_model()这个用法。

cannot import name 'plot_model' from 'keras.utils'

这个问题好解决,因为keras里面确实没有plot_model()用法,但是他的好兄弟——TensorFlow里面有.....

直接这样导入:

from tensorflow.keras.utils import plot_model

就可以了。

第二个报错问题是:

Failed to import pydot. You must `pip install pydot` and install graphviz

意思是缺失两个包,一个pydot,一个graphviz。

我查了很多文章,很多方法比较麻烦,都需要手动下载,手动配置环境变量,后来看到一个很简单的方法,并且也测试有效。

直接在anaconda prompt里面:

conda install graphviz
conda install pydotplus

就可以了,不过安装过程好像会给你装上很多额外的包.....不过不影响环境,神经网络还是一样能跑。


画图测试

导入包,构建一个网络。这个网络是Model类,采用函数API实现,稍微复杂点,可以看图就会清楚他的结构。

导入包

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras.preprocessing import sequence
from keras.models import Sequential,Model
from keras.layers import Dense,Input, Dropout, Embedding, Flatten,MaxPooling1D,Conv1D,SimpleRNN,LSTM,GRU,Multiply
from keras.layers import Bidirectional,Activation,BatchNormalization
from keras.layers.merge import concatenate

from keras.callbacks import EarlyStopping
from tensorflow.keras import regularizers
from keras.utils.np_utils import to_categorical
from tensorflow.keras  import optimizers
from tensorflow.keras.utils import plot_model

定义模型:

inputs = Input(name='inputs',shape=[64,100], dtype='float64')
gru=Bidirectional(GRU(32,return_sequences=True,))(inputs)
mlp = Dense(64,activation='relu')(gru)
attention_probs = Dense(64, activation='softmax', name='attention_vec')(mlp)
attention_mul =  Multiply()([mlp, attention_probs])
mlp = Dense(64)(attention_mul) #原始的全连接
fla=Flatten()(mlp)
output = Dense(2, activation='softmax')(fla)
model = Model(inputs=[inputs], outputs=output)
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])

画出图形:

plot_model(model,'new_model.png',show_shapes=True)

第一参数是神经网络模型,第二个参数是储存的图片名称,第三个是在图片上打印出每层的数据形状。

 用这种图就能很方便的展示组建的模型的架构,多输入多输出都行。

show_shapes=True参数改为False,就可以简化展示图片,不打印形状。我这里换了一种结构的网络。

plot_model(model,'model2.png',show_shapes=False)

 

  • 11
    点赞
  • 47
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

阡之尘埃

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值