- 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
- 🍖 原作者:K同学啊
一、前期准备
1.1数据集下载
CIFAR-10 数据集(Python 版本)下载地址:https://www.cs.toronto.edu/~kriz/cifar.html
1.2 设置 GPU
# TensorFlow is imported here because this cell runs before the main import
# section below; without it `tf` would be undefined at this point.
import tensorflow as tf

# Use the first GPU if one exists, otherwise fall back to the CPU.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    gpu0 = gpus[0]
    try:
        # Allocate GPU memory on demand instead of reserving it all up front.
        tf.config.experimental.set_memory_growth(gpu0, True)
        # Make only the first GPU visible to TensorFlow.
        tf.config.set_visible_devices([gpu0], "GPU")
    except RuntimeError as err:
        # These settings must be applied before the GPUs are initialized
        # (e.g. this cell was re-run in a live session); report and continue.
        print(err)
    print("GPU available.")
else:
    print("GPU not found, using CPU instead.")
1.3 导入依赖库
import tensorflow as tf
import os
import tarfile
import numpy as np
import pickle
from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt
1.4 解压数据集
# Directory the dataset is expected to unpack into (presumably the archive's
# top-level folder, since extraction below targets its parent directory).
extract_path = 'D:/others/pycharm/pythonProject/cifar-10-python/cifar-10-batches-py'
# Extract the archive only if the dataset directory does not exist yet.
if not os.path.exists(extract_path):
    # Local path of the downloaded CIFAR-10 archive.
    local_file_path = 'D:/others/pycharm/pythonProject/cifar-10-python.tar.gz'
    parent_dir = os.path.dirname(extract_path)
    # Create only the parent directory. Creating extract_path itself before
    # extracting would leave an empty directory behind if extraction failed,
    # and the existence check above would then skip extraction forever.
    os.makedirs(parent_dir, exist_ok=True)
    # Use the 'data' extraction filter when this Python supports it; it rejects
    # absolute paths, path traversal and special files inside the archive.
    extract_kwargs = {'filter': 'data'} if hasattr(tarfile, 'data_filter') else {}
    with tarfile.open(local_file_path, 'r:gz') as tar:
        tar.extractall(path=parent_dir, **extract_kwargs)
# Load one training batch file from the extracted dataset directory.
def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
    """Read ``data_batch_<batch_id>`` from *cifar10_dataset_folder_path*.

    Returns ``(features, labels)``: ``features`` is the pickled ``data`` array
    reshaped to (N, 3, 32, 32) and transposed to (N, 32, 32, 3); ``labels`` is
    the pickled ``labels`` entry converted to a NumPy array.
    """
    batch_file = os.path.join(cifar10_dataset_folder_path, f'data_batch_{batch_id}')
    with open(batch_file, mode='rb') as fh:
        record = pickle.load(fh, encoding='latin1')
    raw = record['data']
    # Flat rows -> (N, channel, row, col) -> channels-last (N, row, col, channel).
    images = raw.reshape((len(raw), 3, 32, 32)).transpose(0, 2, 3, 1)
    return images, np.array(record['labels'])
1.5 加载数据
# Loaders for the batch files of the extracted dataset.
def _load_cifar10_file(file_path):
    """Unpickle one CIFAR-10 batch file and return ``(features, labels)``.

    ``features`` is the pickled ``data`` array reshaped from flat rows to
    (N, 3, 32, 32) and transposed to channels-last (N, 32, 32, 3);
    ``labels`` is the pickled ``labels`` entry converted to a NumPy array.
    """
    # encoding='latin1' lets Python 3 read pickles written by Python 2
    # (presumably how these files were produced — confirm if it matters).
    # NOTE(review): pickle can execute arbitrary code; only load trusted files.
    with open(file_path, mode='rb') as file:
        batch = pickle.load(file, encoding='latin1')
    data = batch['data']
    features = data.reshape((len(data), 3, 32, 32)).transpose(0, 2, 3, 1)
    labels = np.array(batch['labels'])
    return features, labels


def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
    """Load the training batch file ``data_batch_<batch_id>`` from the folder."""
    return _load_cifar10_file(
        os.path.join(cifar10_dataset_folder_path, 'data_batch_' + str(batch_id)))


def load_cifar10_test_batch(cifar10_dataset_folder_path):
    """Load the test batch file ``test_batch`` from the folder."""
    return _load_cifar10_file(
        os.path.join(cifar10_dataset_folder_path, 'test_batch'))
cifar10_dataset_folder_path = extract_path
# Load the five training batches (data_batch_1 .. data_batch_5) and stack them.
train_images = []
train_labels = []
for i in range(1, 6):
    features, labels = load_cifar10_batch(cifar10_dataset_folder_path, i)
    train_images.append(features)
    train_labels.append(labels)
train_images = np.concatenate(train_images)
train_labels = np.concatenate(train_labels)
# Load the test batch.
test_images, test_labels = load_cifar10_test_batch(cifar10_dataset_folder_path)

# Scale pixel values to [0, 1] before training (assumes 8-bit pixels in
# [0, 255]). Casting to float32 first avoids a float64 intermediate twice
# the size.
train_images = train_images.astype(np.float32) / 255.0
test_images = test_images.astype(np.float32) / 255.0
1.6 可视化结果
# Print the array shapes: train images, test images, train labels, test labels.
print(train_images.shape, test_images.shape, train_labels.shape, test_labels.shape)
# Show the first 20 training images with their class names.
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
num_images = 20
num_cols = 10
# Size the grid from the image count so no empty rows are drawn.
num_rows = (num_images + num_cols - 1) // num_cols
plt.figure(figsize=(20, 2.5 * num_rows))
for i in range(num_images):
    plt.subplot(num_rows, num_cols, i + 1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    # Images are (32, 32, 3) RGB, so they are drawn in true color;
    # matplotlib ignores any colormap for RGB input.
    plt.imshow(train_images[i])
    plt.xlabel(class_names[train_labels[i]])
plt.show()
二、搭建模型(CNN)
2.1模型结构
# Small CNN: three 3x3 convolution layers (the first two each followed by
# 2x2 max pooling), then a 64-unit dense layer and a 10-unit output layer.
model = models.Sequential()
# Convolution layer 1: 32 filters of 3x3 on a 32x32x3 input image.
model.add(layers.Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)))
model.add(layers.MaxPooling2D((2, 2)))  # pooling layer 1, 2x2 downsampling
# Convolution layer 2: 64 filters of 3x3.
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
model.add(layers.MaxPooling2D((2, 2)))  # pooling layer 2, 2x2 downsampling
# Convolution layer 3: 64 filters of 3x3 (no pooling afterwards).
model.add(layers.Conv2D(64, (3, 3), activation='relu'))
# Flatten the feature maps so they can feed the dense layers.
model.add(layers.Flatten())
model.add(layers.Dense(64, activation='relu'))  # fully connected feature layer
# Output layer: no activation, so it emits one raw score (logit) per class.
model.add(layers.Dense(10))
model.summary()  # print the layer-by-layer architecture
2.2 模型编译
# Configure training before calling fit() (fit() requires a compiled model):
# Adam optimizer, sparse categorical cross-entropy on the integer class
# labels, and accuracy as the reported metric. from_logits=True because the
# final Dense(10) layer has no softmax activation.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
2.3 模型训练
# Train for 10 epochs, evaluating on the test set after every epoch;
# the per-epoch loss/metric values are returned in `history`.
history = model.fit(
    x=train_images,
    y=train_labels,
    epochs=10,
    validation_data=(test_images, test_labels),
)
output
2024-06-05 20:53:33.335456: W tensorflow/core/framework/cpu_allocator_impl.cc:82] Allocation of 614400000 exceeds 10% of free system memory.
Epoch 1/10
2024-06-05 20:53:35.138652: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8100
1563/1563 [==============================] - 15s 5ms/step - loss: 1.5403 - accuracy: 0.4383 - val_loss: 1.2772 - val_accuracy: 0.5382
Epoch 2/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.1675 - accuracy: 0.5840 - val_loss: 1.1246 - val_accuracy: 0.5962
Epoch 3/10
1563/1563 [==============================] - 7s 4ms/step - loss: 1.0179 - accuracy: 0.6398 - val_loss: 1.0707 - val_accuracy: 0.6224
Epoch 4/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.9142 - accuracy: 0.6777 - val_loss: 0.9220 - val_accuracy: 0.6748
Epoch 5/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.8465 - accuracy: 0.7021 - val_loss: 0.8923 - val_accuracy: 0.6891
Epoch 6/10
1563/1563 [==============================] - 7s 5ms/step - loss: 0.7853 - accuracy: 0.7248 - val_loss: 0.8956 - val_accuracy: 0.6873
Epoch 7/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.7328 - accuracy: 0.7426 - val_loss: 0.8921 - val_accuracy: 0.6965
Epoch 8/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.6911 - accuracy: 0.7557 - val_loss: 0.8859 - val_accuracy: 0.6994
Epoch 9/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.6505 - accuracy: 0.7705 - val_loss: 0.8903 - val_accuracy: 0.6986
Epoch 10/10
1563/1563 [==============================] - 7s 4ms/step - loss: 0.6148 - accuracy: 0.7841 - val_loss: 0.8611 - val_accuracy: 0.7093
进程已结束,退出代码为 0