LeNet5网络模型

7 篇文章 0 订阅

结构模型

论文:http://yann.lecun.com/exdb/publis/pdf/lecun-01a.pdf

这个可以说是CNN的开山之作,由Yann LeCun在1998年提出,可以实现对手写数字、字母的识别。结构如下

代码实现

使用tensorflow实现结构, 数据集mnist, 用keras库自带的读取

import tensorflow as tf
from keras.datasets.mnist import load_data

import numpy as np


def LeNet5(x):
    out = tf.reshape(x, [-1, 28, 28, 1])
    out = tf.layers.Conv2D(6, 5, 1, padding="same")(out)
    out = tf.layers.BatchNormalization()(out)
    out = tf.nn.relu(out)

    out = tf.layers.AveragePooling2D(2, 2, padding="same")(out)

    out = tf.layers.Conv2D(16, 5, 1)(out)
    out = tf.layers.BatchNormalization()(out)
    out = tf.nn.relu(out)

    out = tf.layers.AveragePooling2D(2, 2, padding="same")(out)

    out = tf.reshape(out, [-1, 5*5*16])
    out = tf.layers.Dense(120, activation=tf.nn.relu)(out)
    out = tf.layers.Dense(84, tf.nn.relu)(out)
    out = tf.layers.Dense(10)(out)

    return out


def mnist_generator(trainx, trainy, batch_size):
    for i in range(int(len(trainy) / batch_size)):
        yield trainx[i*batch_size: (i+1)*batch_size], trainy[i*batch_size: (i+1)*batch_size]


def train():
    x = tf.placeholder("float", [None, 28, 28])
    y = tf.placeholder(tf.int64, [None])
    (trainx, trainy), (testx, testy) = load_data()

    out = LeNet5(x)
    loss = tf.losses.sparse_softmax_cross_entropy(labels=y, logits=out)
    optimizer = tf.train.AdamOptimizer().minimize(loss)

    acc = tf.count_nonzero(tf.equal(tf.argmax(out, 1), y))

    batch_size = 32
    total_batch = int(len(trainy) / batch_size)
    epochs = 10

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        for epoch in range(epochs):
            datagen = mnist_generator(trainx, trainy, batch_size)
            total_loss = 0.
            for bx, by in datagen:
                _, l = sess.run([optimizer, loss], feed_dict={x: bx, y: by})
                total_loss += l/batch_size
            print("total_loss: ", total_loss, " accuracy: ", sess.run(
                acc, feed_dict={x: testx, y: testy}) / len(testy))


if __name__ == "__main__":
    train()

 

  • 1
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值