tensorflow2 cifar10 模型训练 demo

本文档介绍了使用TensorFlow 2.8、CUDA 11.2和cuDNN 8.1.1在Windows环境下训练CIFAR-10数据集的步骤。首先,通过keras.datasets.cifar10加载并预处理数据,然后构建一个简单的卷积神经网络模型进行训练。在训练过程中遇到了cuDNN版本不匹配的问题,通过更新和配置环境变量解决了问题。最终,模型在验证集上达到了约70.76%的准确率。
摘要由CSDN通过智能技术生成

引言

原文地址: tensorflow2 cifar10 模型训练 demo

欢迎访问我的博客: http://blog.duhbb.com/

文本使用 tensorflow 2.8, CUDA 11.2 以及 cuDNN 8.1.1 训练了 cifar10 数据集. 代码没有那么重要, 主要是完成了环境的安装以及各种问题排查, 最后用一个简单的网络结构跑了一下训练. 如果本文对你有用, 麻烦不吝点个赞; 如果有啥问题, 请不要犹豫, 赶紧联系我.

下载数据集和查看数据

import tensorflow as tf
from keras import datasets, layers, models
import matplotlib.pyplot as plt

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# 归一化处理
train_images, test_images = train_images / 255.0, test_images / 255.0

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

plt.figure(figsize=(10, 10))
for i in range(10):
    plt.subplot(5, 5, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)
    plt.xlabel(class_names[train_labels[i][0]])
plt.show()

下面的这一行会下载对应的数据集:

(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

具体的下载路径就是这里:

C:\Users\tuhoo\.keras\datasets

file

上面的代码跑完, 我们就可以看到具体的图片了:

如何自己下载数据

keras cifar10.load_data() 自己下载数据

keras 下载数据出错

使用 keras 时, 导入cifar10数据会自动下载 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz, 但是有时会下载出错;

解决方法

自己下载cifar-10-python.tar.gz, 然后将文件放到对应路径(~./keras/datasets/) .

将文件名称 cifar-10-python.tar.gz 改为 cifar-10-batches-py.tar.gz .

这两步操作综合, 在所下载文件 cifar-10-python.tar.gz 的根目录下, 使用如下命令:

cp cifar-10-python.tar.gz ~./keras/datasets/cifar-10-batches-py.tar.gz

原文链接:

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值