前言
MNIST是一个广泛用于机器学习和深度学习领域的手写数字图像数据集。它包含了大量的手写数字图片和对应的标签,用于训练和测试图像识别算法。由于其数据规模适中且图像特征清晰,MNIST成为了初学者入门和实践深度学习技术的理想选择。通过训练模型来识别MNIST数据集中的手写数字,可以深入了解神经网络、卷积神经网络等算法的工作原理,为后续的复杂图像识别任务奠定坚实的基础。
一、mnist的下载
如果你是Linux系统可以按照我给的代码运行下载,也可以从官网中下载,链接我贴在下方了
!mkdir dataset
!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz --output-document=dataset/train-images-idx3-ubyte.gz
!gzip -d dataset/train-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz --output-document=dataset/train-labels-idx1-ubyte.gz
!gzip -d dataset/train-labels-idx1-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz --output-document=dataset/t10k-images-idx3-ubyte.gz
!gzip -d dataset/t10k-images-idx3-ubyte.gz
!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz --output-document=dataset/t10k-labels-idx1-ubyte.gz
!gzip -d dataset/t10k-labels-idx1-ubyte.gz
http://yann.lecun.com/exdb/mnist/ 这是官方的链接
二、数据预处理
1.引入库
其中我们强调一下struct是 Python 的一个内置模块,用于处理二进制数据。它提供了一种方式来打包和解包 C 风格的结构体,这对于与二进制文件、网络协议或底层硬件交互特别有用。struct模块提供了 pack和 unpack这两个主要的函数
import numpy as np
import struct
import random
import matplotlib.pyplot as plt
import pandas as pd
import math
2.读入数据
def load\_labels(file):
with open (file,'rb') as f:
data = f.read()
magic_num,num_sample = struct.unpack('>ii',data[:8])
if magic_num !=2049:
print("输入文件错误")
else:
labels = data[:8]
return labels
def load\_images(file):
with open (file,'rb') as f:
data = f.read()
magic_number, num_samples, image_width, image_height = struct.unpack(">iiii", data[:16])
if magic_number !=2051:
print("输入错误")
else:
image_data = np.asarray(list(data[16:]), dtype=np.uint8).reshape(num_samples, -1)
return image_data
val_labels = load_labels("dataset/t10k-labels-idx1-ubyte") # 10000,
val_images = load_images("dataset/t10k-images-idx3-ubyte") # 10000, 784
numdata = val_images.shape[0] # 60000
val_images = np.hstack((val_images / 255 - 0.5, np.ones((numdata, 1)))) # 10000, 785
train_labels = load_labels("dataset/train-labels-idx1-ubyte") # 60000,
train_images = load_images("dataset/train-images-idx3-ubyte") # 60000, 784
numdata = train_images.shape[0] # 60000
train_images = np.hstack((train_images / 255 - 0.5, np.ones((numdata, 1)))) # 60000, 785
train_images.shape
val_images.shape
读入二进制的mnist的文件并且根据magic_number判断是否是正确的文件,>是big-endian,可以判断数字的顺序如下图文件,读取完毕后我们把已经读完的数据进行归一化,便于计算机的计算,最后横向堆叠一行单位向量,数据预处理完成。
三、数据迭代器
class Dataset:
def \_\_init\_\_(self, images, labels):
self.images = images
self.labels = labels
def \_\_getitem\_\_(self, index):
return self.images[index], self.labels[index]
def \_\_len\_\_(self):
return len(self.images)
class DataLoaderIterator:
def \_\_init\_\_(self, dataloader):
self.dataloader = dataloader
self.cursor = 0
self.indexs = list(range(self.dataloader.count_data)) # 0, ... 60000
if self.dataloader.shuffle:
# 打乱一下
random.shuffle(self.indexs)
# 合并batch的数据
def merge\_to(self, container, b):
if len(container) == 0:
for index, data in enumerate(b):
if isinstance(data, np.ndarray):
container.append(data)
else:
container.append(np.array([data], dtype=type(data)))
else:
for index, data in enumerate(b):
container[index] = np.vstack((container[index], data))
return container
def \_\_next\_\_(self):
if self.cursor >= self.dataloader.count_data:
raise StopIteration()
batch_data = []
remain = min(self.dataloader.batch_size, self.dataloader.count_data - self.cursor) # 256, 128
for n in range(remain):
index = self.indexs[self.cursor]
data = self.dataloader.dataset[index]
batch_data = self.merge_to(batch_data, data)
self.cursor += 1
return batch_data
class DataLoader:
def \_\_init\_\_(self, dataset, batch_size, shuffle):
self.dataset = dataset
self.shuffle = shuffle
self.count_data = len(dataset)
self.batch_size = batch_size
def \_\_iter\_\_(self):
return DataLoaderIterator(self)
这段代码定义了三个类:Dataset、DataLoaderIterator和DataLoader,它们共同构成了一个简单的数据加载和批处理框架。Dataset类用于存储图像和对应的标签;DataLoaderIterator类作为迭代器,用于按批次从数据集中提取数据,并可以在需要时打乱数据顺序;DataLoader类则负责创建和管理DataLoaderIterator,len(self.images)返回的是images的长度,merge_to对data加到batch_size里面并且判断是否为numpy形式。cursor 是个浮动标签用于检验batch_size所需要的大小。
模型的运用
设定学习率及各种辅助函数
def estimate(plabel, gt_labels, classes):
plabel = plabel.copy()
gt_labels = gt_labels.copy()
match_mask = plabel == classes
mismatch_mask = plabel != classes
plabel[match_mask] = 1
plabel[mismatch_mask] = 0
gt_mask = gt_labels == classes
gt_mismatch_mask = gt_labels != classes
gt_labels[gt_mask] = 1
gt_labels[gt_mismatch_mask] = 0
TP = sum(plabel & gt_labels)
FP = sum(plabel & (1 - gt_labels))
FN = sum((1 - plabel) & gt_labels)
TN = sum((1 - plabel) & (1 - gt_labels))
precision = TP / (TP + FP)
recall = TP / (TP + FN)
accuracy = (TP + TN) / (TP + FP + FN + TN)
F1 = 2 \* (precision \* recall) / (precision + recall)
return precision, recall, accuracy, F1
def estimate\_val(images, gt_labels, theta, classes):
predict = sigmoid(val_images @ theta)
plabel = predict.argmax(1)
prob = plabel == val_labels
total_images = images.shape[0]
accuracy = sum(prob) / total_images
return accuracy, cross_entropy(predict, one_hot(gt_labels, classes))
def cross\_entropy(predict, gt):
eps = 1e-4
predict = np.clip(predict, a_max=1-eps, a_min=eps) # 裁切
batch_size = predict.shape[0]
return -np.sum(gt \* np.log(predict) + (1 - gt) \* np.log(1 - predict)) / batch_size
def lr\_schedule\_cosine(lr_min, lr_max, per_epochs):
def compute(epoch):
return lr_min + 0.5 \* (lr_max - lr_min) \* (1 + np.cos(epoch / per_epochs \* np.pi))
return compute
余弦退火学习率+周期性重启,详情可以参考我之前的文章,cross_entropy计算损失率,estimate来进行评估使用。注:==的优先级大于=
选取模型+模型训练
这段代码实现了一个基本的神经网络训练流程,并通过动态调整学习率来优化训练过程。通过记录和展示训练过程中的损失值和验证性能,我们可以更好地了解模型的训练情况和性能表现。然而,需要注意的是,代码中存在一些未定义的函数和变量(如lr_schedule_cosine、cross_entropy、estimate_val等),这些需要在实际运行代码之前进行定义和实现。注意:if epoch in warm_up_lr 的缩进我之前因为缩进不到位所以程序不能运行。
import matplotlib.pyplot as plt
def sigmoid(x):
**自我介绍一下,小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。**
**深知大多数Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但对于培训机构动则几千的学费,着实压力不小。自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!**
**因此收集整理了一份《2024年Python开发全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。**
![img](https://img-blog.csdnimg.cn/img_convert/0df2c2c1003291a002bea824d24bb853.png)
![img](https://img-blog.csdnimg.cn/img_convert/6c8955be797896c355b388a4af0e8a24.png)
![img](https://img-blog.csdnimg.cn/img_convert/a9a78eb65cd897190d5b624209b16ea8.png)
![img](https://img-blog.csdnimg.cn/img_convert/feb03381acffb4e1355a09f4b58edbff.png)
![img](https://img-blog.csdnimg.cn/img_convert/6c361282296f86381401c05e862fe4e9.png)
![img](https://img-blog.csdnimg.cn/img_convert/9f49b566129f47b8a67243c1008edf79.png)
**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**
**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**
**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**
://img-blog.csdnimg.cn/img_convert/9f49b566129f47b8a67243c1008edf79.png)
**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**
**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**
**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**
![](https://img-blog.csdnimg.cn/img_convert/36ba2a8db293239943a51d35a2d5f9ad.jpeg)