tensorflow实现模型可视化(plot_model函数)

###    过程成遇到的错误

Failed to import pydot. You must install pydot and graphviz for `pydotprint` to work.

(意思要安装pydot和graphiviz模块,安装过程会有许多意想不到的错误)

 

####   下面是安装成功的环境配置和步骤(建议使用Anconda+conda命令符安装)

python:3.6.2

keras:2.1.0

tensorflow:1.2.1

pydot-ng:2.0.0

graphiviz:2.38.0

## 步骤:

首先第一步:pip install graphviz

第二步:安装graphviz

第三步:pip install pydot-ng

测试代码:


from keras.models import Sequential
from keras.layers import Conv2D , MaxPool2D , Flatten , Dense , Dropout
from keras.optimizers import Adam
from keras.utils.vis_utils import plot_model
import matplotlib.pyplot as plt

'''cnn模型的建立 begin'''
model=Sequential()
#Relu函数:通常指代以斜坡函数及其变种为代表非非线性函数
#cnn中激活函数:在神经网络中,为避免单纯的线性组合,我们在每一层的输出后面都添加一个激活函数作为非线性因素
'''
conv2D参数介绍:
当使用该层作为模型第一层时,需要提供input_shape参数
filters:整数,输出空间的维度
kernel_size:一个整数或两个整数表示的元组或列表,指明卷积窗口的宽度和高度
padding: "valid" 或 "same" (大小写敏感)。   valid padding就是不padding,而same padding就是指padding完尺寸与原来相同
图像识别一般来说都要padding,尤其是在图片边缘的特征重要的情况下。padding多少取决于我们需要的输出是多少
'''
model.add(Conv2D(input_shape=(150,150,3),filters=32,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=32,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))

model.add(Conv2D(filters=64,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=64,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))

model.add(Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(Conv2D(filters=128,kernel_size=3,padding='same',activation='relu'))
model.add(MaxPool2D(pool_size=2,strides=2))

model.add(Flatten())
model.add(Dense(units=64,activation='relu'))
'''
Dropout 是在训练圣经网络时,样本数据过少,防止过拟合而采用的
'''
model.add(Dropout(0.5))
model.add(Dense(2,activation='softmax'))

#定义优化器
adam=Adam(lr=1e-4)
#定义优化器,代价函数,训练过程中的准确率
model.compile(optimizer=adam,loss='categorical_crossentropy',metrics=['accuracy'])

#模型网络结构图输出
plot_model(model,to_file='model1.png',show_shapes=True,show_layer_names=True,rankdir='TB')
plt.figure(figsize=(10,10))
img=plt.imread('model1.png')
plt.imshow(img)
plt.axis('off')
plt.show()

结果图片展示:

 

已标记关键词 清除标记
相关推荐
©️2020 CSDN 皮肤主题: 技术黑板 设计师:CSDN官方博客 返回首页