活动地址:CSDN21天学习挑战赛
在上篇博客中对LetNet模型进行了介绍:理论篇1:深度学习之----LetNet模型详解;
今天我们我们通过tensorflow2,实现LetNet模型的编写。
环境的配置:
1、环境
- 语言:python,版本3.7.0;(python安装教程网上有很多教程都很好,可自行查询安装)
- 编译器:pycharm;
- 深度学习框架tensorflow2;
2、配置方法
python和pycharmde安装教程网上有很多教程都很好,可自行查询安装;
下面主要介绍tensorflow2的安装方法:
- 方法一、通过pip在线安装
pip install tensorflow
注:使用这种方式安装,有时候会因为网络原因造成安装失败;这种情况下尝试第二种安装方式。
- 方法二、下载whl文件,本地安装
下载whl文件的地址:https://pypi.org/search/?q=tensorflow
// cd 到.whl文件的文件夹
pip install tensorflow-2.9.1-cp37-cp37m-win_amd64.whl
程序的编写:
以上环境配置完成后就可以进行程序的看编写,接下来将进行LetNet模型的编写:
LetNet.py
#构建网络
model = models.Sequential([
layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1)), # 卷积层1,卷积核3*3
layers.MaxPooling2D((2, 2)), # 池化层1,2*2采样
layers.Conv2D(64, (3, 3), activation='relu'), # 卷积层2,卷积核3*3
layers.MaxPooling2D((2, 2)), # 池化层2,2*2采样
layers.Flatten(), # Flatten层,连接卷积层与全连接层
layers.Dense(64, activation='relu'), # 全连接层,特征进一步提取
layers.Dense(10) # 输出层,输出预期结果
])
# 打印网络结构
model.summary()
网络详情:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
conv2d (Conv2D) (None, 26, 26, 32) 320
_________________________________________________________________
max_pooling2d (MaxPooling2D) (None, 13, 13, 32) 0
_________________________________________________________________
conv2d_1 (Conv2D) (None, 11, 11, 64) 18496
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (None, 5, 5, 64) 0
_________________________________________________________________
flatten (Flatten) (None, 1600) 0
_________________________________________________________________
dense (Dense) (None, 64) 102464
_________________________________________________________________
dense_1 (Dense) (None, 10) 650
=================================================================
Total params: 121,930
Trainable params: 121,930
Non-trainable params: 0
_________________________________________________________________
知识点讲解:
1、model.summary() //打印网络
2、flatten() //此函数作用就是把所有的参数转换成一维的
如有错误,欢迎大家指正!