LeNet源码(基于tensorflow2.0)

手写字体识别模型LeNet5诞生于1994年,是最早的卷积神经网络之一。LeNet5通过巧妙的设计,利用卷积、参数共享、池化等操作提取特征,避免了大量的计算成本,最后再使用全连接神经网络进行分类识别,这个网络也是最近大量神经网络架构的起点。
LeNet架构如下:
在这里插入图片描述

LeNet_5

import cv2,os
import numpy as np
import tensorflow as tf
from random import shuffle
from tensorflow.keras.models import load_model
from tensorflow.keras import Sequential,layers,optimizers,losses,metrics

labels_num=2 # 类别数

#测试集的导入
def load_image(path,shape):
    img_list = []
    label_list = []
    dir_counter = 0
    # 对路径下的所有子文件夹中的所有jpg文件进行读取并存入到一个list中
    for child_dir in os.listdir(path):
        child_path = os.path.join(path, child_dir)
        for dir_image in os.listdir(child_path):
            img = cv2.imread(os.path.join(child_path, dir_image))
            img = img / 255.0
            img=cv2.resize(img,(shape[0],shape[1]))
            img_list.append(img)
            label_list.append(dir_counter)
        dir_counter += 1

    length= len(img_list)
    index = [i for i in range(length)]
    shuffle(index)  # 打乱索引
    img_np=np.array(img_list)
    label_np=np.array(label_list)
    img_np1 = img_np[index]
    label_np1 = label_np[index]
    train_l=int(0.7*length)

    train_data = np.array(img_np1)[0:train_l]
    train_label =np.array(label_np1)[0:train_l]
    test_data = np.array(img_np1)[train_l:length]
    test_label = np.array(label_np1)[train_l:length]
    return train_data,train_label,test_data,test_label
def model(label_num=labels_num):
    #网络层的搭建
    networks=Sequential([
        layers.Conv2D(6,kernel_size=3,strides=1,activation='relu'),
        layers.MaxPooling2D(pool_size=2,strides=2),
        layers.Conv2D(16,kernel_size=3,strides=1,activation='relu'),
        layers.MaxPooling2D(pool_size=2,strides=2),
        layers.Flatten(),
        layers.Dense(120,activation='relu'),
        layers.Dense(84,activation='relu'),
        layers.Dense(label_num) #输出层,没有激活函数(激活函数为None)
        ])
    return networks
def train(net,train_data,train_label):
    def get_batch(batch_size, i):
        x = batch_size * i
        train_data_batch = train_data[x:x + batch_size, :]
        train_lable_batch = train_label[x:x + batch_size]
        return train_data_batch, train_lable_batch

    epoch = 5  # 迭代次数
    batch_size = 32  # 一批要处理的图像
    shape_t=train_data.shape
    net.build(input_shape=(batch_size,shape_t[1],shape_t[2],shape_t[3]))
    num_train_data = shape_t[0]  # 训练图像总数
    batch_num = int(num_train_data // batch_size)  # 训练批数:这里必须取整
    optimizer = optimizers.Adam(learning_rate=0.001)  # 该函数可以设置一个随训练进行逐渐减小的学习率,此处我们简单的设学习率为常量
    for n in range(epoch):
        for i in range(batch_num):
            with tf.GradientTape() as tape:  # with语句内引出需要求导的量
                x, y = get_batch(batch_size, i)
                out = net(x)
                y_onehot = tf.one_hot(y, depth=labels_num)  # 一维表示类别(0-9)-> 二维表示类别(1,0,0,0,...)...
                loss_object = losses.CategoricalCrossentropy(from_logits=True)  # 交叉熵损失函数.这是一个类,loss_object为类的实例化对象
                loss = loss_object(y_onehot, out)  # 使用损失函数类来计算损失
                print('epoch:%d batch:%d loss:%f' % (n, i, loss.numpy()))
            grad = tape.gradient(loss, net.trainable_variables)  # 用以自动计算梯度. loss对网络中的所有参数计算梯度
            optimizer.apply_gradients(zip(grad, net.trainable_variables))  # 根据梯度更新网络参数
    net.save('model/lenet.h5')

def test(test_data,test_label):
    net=load_model('model/lenet.h5')
    batch_size=32
    s_c_a = metrics.SparseCategoricalAccuracy()  # metrics用于监测性能指标,这里用update_state来对比
    num_test_batch = int(test_data.shape[0] // batch_size)  # 测试集数量
    for num_index in range(num_test_batch):
        start_index, end_index = num_index * batch_size, (num_index + 1) * batch_size  # 每一批的起始索引和结束索引
        y_predict = net.predict(test_data[start_index:end_index])
        s_c_a.update_state(y_true=test_label[start_index:end_index], y_pred=y_predict)
    print('test accuracy:%f' % s_c_a.result())

if __name__ == '__main__':
    path = "E:/project_file/dataset/horse-or-human/valid"
    train_data,train_label,test_data,test_label=load_image(path,(244,244))
    net = model()
    train(net,train_data,train_label)
    print('------------------------------')
    test(test_data,test_label)

模型架构

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (32, 26, 26, 6)           60        
_________________________________________________________________
max_pooling2d (MaxPooling2D) (32, 13, 13, 6)           0         
_________________________________________________________________
conv2d_1 (Conv2D)            (32, 11, 11, 16)          880       
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 (32, 5, 5, 16)            0         
_________________________________________________________________
flatten (Flatten)            (32, 400)                 0         
_________________________________________________________________
dense (Dense)                (32, 120)                 48120     
_________________________________________________________________
dense_1 (Dense)              (32, 84)                  10164     
_________________________________________________________________
dense_2 (Dense)              (32, 10)                  850       
=================================================================
Total params: 60,074
Trainable params: 60,074
Non-trainable params: 0

测试精度

test accuracy:0.988281
  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
l_e多页面浏览器(0.6)源码 2010-9-18 在09年2月份的时候版本就变成0.6了,那时候就想着要把源码公布出来. 至于为何拖到现在才发布?一直没想到应该用怎样的版权声明. 到了现今,貌似也没必要想这个事情了,各位喜欢怎样用这份源码都行,如果可以的话,最好能提一下我,^0^ 另,库找不到的问题,可以先把所有的引用去掉,然后再依次添加. l_e多页面浏览器(1.4.136)源码 2007-2-12 其实此为0.5preview的版本,但有还没做到预期希望完成的功能,还差些吧,于是还叫1.4, 本来是暂时不想放出来的,想把未完成的都完成了再放出来,但最近也懒了,还是先把这个放出来吧, 修改的东西比较多,特别是插件部分的代码,详细还是看devlog.txt吧 l_e多页面浏览器(1.4.56)源码 2006-1-11 源码所作的更新可以在devlog.txt中看到 myacc是实现"监视所有下载项"的主要代码,用vc写 l_e多页面浏览器(1.4.0)源码 2005-8-30 =========== 目录 ============ 浏览器特点 使用到的技巧 各文件夹作用 其他 ==============浏览器特点============== 1.多页面浏览(呵呵,此为废话) 2.支持鼠标手势,并且可以自定义 3.支持页面拖拽,并且可以定义拖拽各方向的功能(类似GreenBrowser).拖拽开启时,页面中原本拖拽所实现的部分功能能正常使用,如将一段文字拖拽进一个textarea,input等,好像现在网上流行的多页面浏览器并不具备这个功能哦. 4.支持插件 5.广告过滤(效果大概没有现在那些热门多页面浏览器好,如maxthon) 6.页面规则,对地址满足一定条件的页面执行相应操作 7.自定义搜索引擎,可使用多引擎搜索 8.方便代理服务器切换 9.捕获下载地址,并使用指定的下载工具下载 10.rss阅读(利用插件) 11.绿色 12...... ==============使用到的技巧============== 1.鼠标手势的实现 2.动态生成菜单并相应其点击事件,响应右键,可弹出菜单 3.自画菜单 4.利用资源文件生成菜单,并将其嵌入toolbar中 5.能够对单独网页设定是否显示图片,ActiveX,动画,音乐...等东西 6.实现了前进后退历史记录的显示 7.如何使用代理 8.使用ini文件 9.托盘图标 10.用api生成Toolbar,ListView,TreeView,TabStrip,StatusBar(代码另附) 11.无需注册就可以使用com组件(如vb生成的ActiveX Dll) 12.插件系统,这个系统完全是我自己想出来的,可能未必完善 13.关于webbrowser的东西 14..... ==============各文件夹作用===================== \MDI 主程序源码 \TLBz 主程序缺什么库(如tlb),到这里来找就有了 \ClearCache 一个清除缓存等的工具(源码) \Plugin_LIB 制作插件相关的tlb源码 \Plugins 以写成的插件源码 \Plugins\RssRead rss阅读插件 \release exe \rundllvb 用来调用dll(用于dll型插件)源码(vc6) \IEMouseHand 是一个BHO(browser help object),用于对付那些对话框式弹出页面 ==============其他============== 1.关于库"jccatch.dll#jccatch 1.0 Type Library", 这个是flashget的东西,可以不要,并将frmFlashgetDownload中的Sub AddUrl()中的内容注释掉就行了 2.需要设置一下vb才能在ide中正常运行 tools->options->general, error trapping那里选上 break on unhandled errors 3.这个程序是从2001年开始写的,那时候写的代码在各方面都不成熟,注释就不用说了,也就最近新写或改写的代码才有部分注释,各位大虾就有怪莫怪了. 详细文档,呵呵,懒人啊,以后再慢慢补上吧,现在先把代码发布了再说. 4.部分问题(如鼠标手势)可以到我的blog看看. 5.这个东西我自己一直都在用的,所以不断会有更新,也会上传到我的主页. 6.请不要用作商业用途(估计也没人用吧,^_^) ===================================== by lingll 2005-8-30 blog: http://blog.csdn.net/li
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

冷冰殇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值