基于知识蒸馏的深度学习神经网络图像分类系统源码

本文介绍了如何使用CIFAR-10数据集训练深度学习模型,包括使用VGG19作为teacher网络和自定义CNN作为student网络。通过知识蒸馏技术,先训练teacher模型,然后指导student网络学习,实现更高效的图像分类。
摘要由CSDN通过智能技术生成

 第一步:准备数据

cifar10开源数据集
(x_train, y_train), (x_valid,y_valid) = keras.datasets.cifar10.load_data()

第二步:搭建模型

teacher网络是vgg19,student网络是简单的cnn网络:

 由于是十分类问题,直接套用网络肯定是不行,因此会在全连接部分做手脚,参考代码如下:

def build_teacher_model(name='teacher'):
    base_model = keras.applications.VGG19(input_shape=IMAGE_SIZE, include_top=False)
    base_model.trainable = True
    return keras.models.Sequential([
        base_model,
        L.GlobalAvgPool2D(),
        L.Dense(N_CLASSES, activation='softmax')
    ], name=name
    )

另外,student网络,参考代码如下:

def build_student_model(name='student'):
    return keras.models.Sequential([
        L.Conv2D(64, 3, input_shape=IMAGE_SIZE, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.MaxPool2D(pool_size=2),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.MaxPool2D(pool_size=2),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.MaxPool2D(pool_size=2),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.MaxPool2D(pool_size=2),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.Conv2D(64, 3, padding='same', activation='relu'),
        L.MaxPool2D(pool_size=2),
        L.GlobalAvgPool2D(),
        L.Dense(N_CLASSES,activation='softmax'),
    ],name=name)

第三步:训练代码

1)蒸馏原理参考我的一篇博客知识蒸馏(Distillation)简介_天竺街潜水的八角的博客-CSDN博客

2)先训练teacher网络,再结合训练student网络

teacher_model.compile(
    optimizer=keras.optimizers.Adam(1e-5),
    loss=keras.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['accuracy']
)

history = teacher_model.fit(
    d_train.shuffle(1024, 19).batch(BATCH_SIZE),
    validation_data=d_valid.shuffle(1024, 19).batch(BATCH_SIZE),
    epochs=T_EPOCHS,
    callbacks=nn_callbacks(),
    batch_size=BATCH_SIZE
)


distiller = Distiller(student_model, teacher_model, tf.nn.softmax)
distiller.compile(
    optimizer=keras.optimizers.Adam(),
    metrics=['accuracy'],
    student_loss_fn=keras.losses.CategoricalCrossentropy(from_logits=True),
    distillation_loss_fn=keras.losses.KLDivergence(),
    alpha=0.7,
    temperature=100,
)
history_distillation = distiller.fit(
    d_train.shuffle(1024, 19).batch(BATCH_SIZE),
    validation_data=d_valid.shuffle(1024, 19).batch(BATCH_SIZE),
    epochs=S_EPOCHS, callbacks=nn_callbacks(), batch_size=BATCH_SIZE
)

第四步:统计训练过程

第五步:搭建GUI界面

第六步:整个工程的内容

有训练代码和训练好的模型以及训练过程,提供数据,提供GUI界面代码,主要使用方法可以参考里面的“文档说明_必看.docx”

 代码的下载路径(新窗口打开链接)基于知识蒸馏的cifar10图像分类系统源码

有问题可以私信或者留言,有问必答

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值