mnist算法实现(1)

本文介绍了如何下载和预处理MNIST手写数字数据集,包括数据结构解析、归一化处理和数据迭代器的实现。此外,还涉及了如何运用神经网络模型进行训练,并演示了如何评估模型性能和调整学习率策略。
摘要由CSDN通过智能技术生成

前言

MNIST是一个广泛用于机器学习和深度学习领域的手写数字图像数据集。它包含了大量的手写数字图片和对应的标签,用于训练和测试图像识别算法。由于其数据规模适中且图像特征清晰,MNIST成为了初学者入门和实践深度学习技术的理想选择。通过训练模型来识别MNIST数据集中的手写数字,可以深入了解神经网络、卷积神经网络等算法的工作原理,为后续的复杂图像识别任务奠定坚实的基础。


一、mnist的下载

如果你是Linux系统可以按照我给的代码运行下载,也可以从官网中下载,链接我贴在下方了

!mkdir dataset
!wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz --output-document=dataset/train-images-idx3-ubyte.gz
!gzip -d dataset/train-images-idx3-ubyte.gz

!wget http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz --output-document=dataset/train-labels-idx1-ubyte.gz
!gzip -d dataset/train-labels-idx1-ubyte.gz

!wget http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz --output-document=dataset/t10k-images-idx3-ubyte.gz
!gzip -d dataset/t10k-images-idx3-ubyte.gz

!wget http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz --output-document=dataset/t10k-labels-idx1-ubyte.gz
!gzip -d dataset/t10k-labels-idx1-ubyte.gz

http://yann.lecun.com/exdb/mnist/ 这是官方的链接

二、数据预处理

1.引入库

其中我们强调一下struct是 Python 的一个内置模块,用于处理二进制数据。它提供了一种方式来打包和解包 C 风格的结构体,这对于与二进制文件、网络协议或底层硬件交互特别有用。struct模块提供了 pack和 unpack这两个主要的函数

import numpy as np
import struct
import random
import matplotlib.pyplot as plt
import pandas as pd
import math

2.读入数据

def load\_labels(file):
    with open (file,'rb') as f:
        data = f.read()      
    magic_num,num_sample = struct.unpack('>ii',data[:8])
    if magic_num !=2049:
        print("输入文件错误")
        
    else:
        labels = data[:8]
        return labels
    
    
def load\_images(file):
    with open (file,'rb') as f:
        data = f.read()
    magic_number, num_samples, image_width, image_height = struct.unpack(">iiii", data[:16])
    if magic_number !=2051:
        print("输入错误")
        
    else:
        image_data = np.asarray(list(data[16:]), dtype=np.uint8).reshape(num_samples, -1)
        return image_data
    
val_labels = load_labels("dataset/t10k-labels-idx1-ubyte")   # 10000,
val_images = load_images("dataset/t10k-images-idx3-ubyte")   # 10000, 784
numdata = val_images.shape[0] # 60000
val_images = np.hstack((val_images / 255 - 0.5, np.ones((numdata, 1))))   # 10000, 785

train_labels = load_labels("dataset/train-labels-idx1-ubyte") # 60000,
train_images = load_images("dataset/train-images-idx3-ubyte") # 60000, 784
numdata = train_images.shape[0] # 60000
train_images = np.hstack((train_images / 255 - 0.5, np.ones((numdata, 1))))  # 60000, 785
train_images.shape
val_images.shape
    

读入二进制的mnist的文件并且根据magic_number判断是否是正确的文件,>是big-endian,可以判断数字的顺序如下图文件,读取完毕后我们把已经读完的数据进行归一化,便于计算机的计算,最后横向堆叠一行单位向量,数据预处理完成。
在这里插入图片描述

三、数据迭代器

class Dataset:
    def \_\_init\_\_(self, images, labels):
        self.images = images
        self.labels = labels

    def \_\_getitem\_\_(self, index):
        return self.images[index], self.labels[index]

    def \_\_len\_\_(self):
        return len(self.images)


class DataLoaderIterator:
    def \_\_init\_\_(self, dataloader):
        self.dataloader = dataloader
        self.cursor = 0
        self.indexs = list(range(self.dataloader.count_data))  # 0, ... 60000
        if self.dataloader.shuffle:
            # 打乱一下
            random.shuffle(self.indexs)
    # 合并batch的数据
    def merge\_to(self, container, b):
        if len(container) == 0:
            for index, data in enumerate(b):
                if isinstance(data, np.ndarray):
                    container.append(data)
                else:
                    container.append(np.array([data], dtype=type(data)))
        else:
            for index, data in enumerate(b):
                container[index] = np.vstack((container[index], data))
        return container
    
    
    def \_\_next\_\_(self):
        if self.cursor >= self.dataloader.count_data:
            raise StopIteration()
            
        batch_data = []
        remain = min(self.dataloader.batch_size, self.dataloader.count_data - self.cursor)  # 256, 128
        for n in range(remain):
            index = self.indexs[self.cursor]
            data = self.dataloader.dataset[index]
            batch_data = self.merge_to(batch_data, data)
            self.cursor += 1
        return batch_data
              
            
class DataLoader:
    def \_\_init\_\_(self, dataset, batch_size, shuffle):
        self.dataset = dataset
        self.shuffle = shuffle
        self.count_data = len(dataset)
        self.batch_size = batch_size
    def \_\_iter\_\_(self):
        return DataLoaderIterator(self)
        

这段代码定义了三个类:Dataset、DataLoaderIterator和DataLoader,它们共同构成了一个简单的数据加载和批处理框架。Dataset类用于存储图像和对应的标签;DataLoaderIterator类作为迭代器,用于按批次从数据集中提取数据,并可以在需要时打乱数据顺序;DataLoader类则负责创建和管理DataLoaderIterator,len(self.images)返回的是images的长度,merge_to对data加到batch_size里面并且判断是否为numpy形式。cursor 是个浮动标签用于检验batch_size所需要的大小。

模型的运用

设定学习率及各种辅助函数

def estimate(plabel, gt_labels, classes):
    plabel = plabel.copy()
    gt_labels = gt_labels.copy()
    match_mask = plabel == classes
    mismatch_mask = plabel != classes
    plabel[match_mask] = 1
    plabel[mismatch_mask] = 0
    
    gt_mask = gt_labels == classes
    gt_mismatch_mask = gt_labels != classes
    gt_labels[gt_mask] = 1
    gt_labels[gt_mismatch_mask] = 0
    
    TP = sum(plabel & gt_labels)
    FP = sum(plabel & (1 - gt_labels))
    FN = sum((1 - plabel) & gt_labels)
    TN = sum((1 - plabel) & (1 - gt_labels))
    
    precision = TP / (TP + FP)
    recall = TP / (TP + FN)
    accuracy = (TP + TN) / (TP + FP + FN + TN)
    F1 = 2 \* (precision \* recall) / (precision + recall)
    return precision, recall, accuracy, F1

def estimate\_val(images, gt_labels, theta, classes):
    predict = sigmoid(val_images @ theta)
    plabel = predict.argmax(1)
    prob = plabel == val_labels
    total_images = images.shape[0]
    accuracy = sum(prob) / total_images
    return accuracy, cross_entropy(predict, one_hot(gt_labels, classes))

def cross\_entropy(predict, gt):
    eps = 1e-4
    predict = np.clip(predict, a_max=1-eps, a_min=eps)  # 裁切
    batch_size = predict.shape[0]
    return -np.sum(gt \* np.log(predict) + (1 - gt) \* np.log(1 - predict)) / batch_size

def lr\_schedule\_cosine(lr_min, lr_max, per_epochs):
    def compute(epoch):
        return lr_min + 0.5 \* (lr_max - lr_min) \* (1 + np.cos(epoch / per_epochs \* np.pi))
    return compute


余弦退火学习率+周期性重启,详情可以参考我之前的文章,cross_entropy计算损失率,estimate来进行评估使用。注:==的优先级大于=

选取模型+模型训练

这段代码实现了一个基本的神经网络训练流程,并通过动态调整学习率来优化训练过程。通过记录和展示训练过程中的损失值和验证性能,我们可以更好地了解模型的训练情况和性能表现。然而,需要注意的是,代码中存在一些未定义的函数和变量(如lr_schedule_cosine、cross_entropy、estimate_val等),这些需要在实际运行代码之前进行定义和实现。注意:if epoch in warm_up_lr 的缩进我之前因为缩进不到位所以程序不能运行。

import matplotlib.pyplot as plt

def sigmoid(x):
 **自我介绍一下,小编13年上海交大毕业,曾经在小公司待过,也去过华为、OPPO等大厂,18年进入阿里一直到现在。**

**深知大多数Python工程师,想要提升技能,往往是自己摸索成长或者是报班学习,但对于培训机构动则几千的学费,着实压力不小。自己不成体系的自学效果低效又漫长,而且极易碰到天花板技术停滞不前!**

**因此收集整理了一份《2024年Python开发全套学习资料》,初衷也很简单,就是希望能够帮助到想自学提升又不知道该从何学起的朋友,同时减轻大家的负担。**

![img](https://img-blog.csdnimg.cn/img_convert/0df2c2c1003291a002bea824d24bb853.png)

![img](https://img-blog.csdnimg.cn/img_convert/6c8955be797896c355b388a4af0e8a24.png)

![img](https://img-blog.csdnimg.cn/img_convert/a9a78eb65cd897190d5b624209b16ea8.png)

![img](https://img-blog.csdnimg.cn/img_convert/feb03381acffb4e1355a09f4b58edbff.png)

![img](https://img-blog.csdnimg.cn/img_convert/6c361282296f86381401c05e862fe4e9.png)

![img](https://img-blog.csdnimg.cn/img_convert/9f49b566129f47b8a67243c1008edf79.png)

**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**

**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**

**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**

://img-blog.csdnimg.cn/img_convert/9f49b566129f47b8a67243c1008edf79.png)

**既有适合小白学习的零基础资料,也有适合3年以上经验的小伙伴深入学习提升的进阶课程,基本涵盖了95%以上前端开发知识点,真正体系化!**

**由于文件比较大,这里只是将部分目录大纲截图出来,每个节点里面都包含大厂面经、学习笔记、源码讲义、实战项目、讲解视频,并且后续会持续更新**

**如果你觉得这些内容对你有帮助,可以扫码获取!!!(备注:Python)**

![](https://img-blog.csdnimg.cn/img_convert/36ba2a8db293239943a51d35a2d5f9ad.jpeg)
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值