引言
原文地址: tensorflow2 cifar10 模型训练 demo
欢迎访问我的博客: http://blog.duhbb.com/
文本使用 tensorflow 2.8, CUDA 11.2 以及 cuDNN 8.1.1 训练了 cifar10 数据集. 代码没有那么重要, 主要是完成了环境的安装以及各种问题排查, 最后用一个简单的网络结构跑了一下训练. 如果本文对你有用, 麻烦不吝点个赞; 如果有啥问题, 请不要犹豫, 赶紧联系我.
下载数据集和查看数据
import tensorflow as tf
from keras import datasets, layers, models
import matplotlib.pyplot as plt
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
# 归一化处理
train_images, test_images = train_images / 255.0, test_images / 255.0
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck']
plt.figure(figsize=(10, 10))
for i in range(10):
plt.subplot(5, 5, i + 1)
plt.xticks([])
plt.yticks([])
plt.grid(False)
plt.imshow(train_images[i], cmap=plt.cm.binary)
plt.xlabel(class_names[train_labels[i][0]])
plt.show()
下面的这一行会下载对应的数据集:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()
具体的下载路径就是这里:
C:\Users\tuhoo\.keras\datasets
上面的代码跑完, 我们就可以看到具体的图片了:
如何自己下载数据
keras cifar10.load_data() 自己下载数据
keras 下载数据出错
使用 keras 时, 导入cifar10数据会自动下载 https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz, 但是有时会下载出错;
解决方法
自己下载cifar-10-python.tar.gz, 然后将文件放到对应路径(~./keras/datasets/)
.
将文件名称 cifar-10-python.tar.gz
改为 cifar-10-batches-py.tar.gz
.
这两步操作综合, 在所下载文件 cifar-10-python.tar.gz 的根目录下, 使用如下命令:
cp cifar-10-python.tar.gz ~./keras/datasets/cifar-10-batches-py.tar.gz
原文链接: