【机器学习】- 感知机(mnist数据集)

算法:
在这里插入图片描述
感知机算法计算的超平面S会因为权值的初始值误分类点的选择顺序变化

性能
在对946个测试样本进行测试之后发现,错误率仅为0.42%

数据集和代码可以在这里下载

CODE

from os import listdir
import numpy as np

#读取数据
def readData(fileFolder):
    fileList = listdir(fileFolder)
    trainDataSet = []
    trainLabel = []
    for fileName in fileList:
        #label
        label = [0 for i in range(10)]
        labelDigit = int(fileName.split('_')[0])
        label[labelDigit] = 1
        trainLabel.append(label)
        #data
        ifile = open(fileFolder + "/" + fileName)
        lines = ifile.readlines()
        dataSet = []
        for line in lines:
            line = line.split('\n')[0]
            for i in range(len(line)):
                dataSet.append(int(line[i]))
        trainDataSet.append(dataSet)

    return trainDataSet, trainLabel

def changeFrom(trainLabel, index):
    recY = []
    for label in trainLabel:
        y = [-1]
        if(label[index] == 1):
            y[0] = 1
        recY.append(y)
    return recY

#训练模型
def train_perceptron(trainDataSet, y):
    m = len(trainDataSet)  #  数据量
    n = len(trainDataSet[0]) #  特征数量
    learn_rate = 1
    #初始化模型参数
    w = np.zeros((1, n))
    b = 0
    #开始训练
    hasErrorData = True
    while(hasErrorData == True):
        hasErrorData = False
        for i in range(m):
            data = np.array(trainDataSet[i])
            #错误数据
            if(((w.dot(data.T) + b) * y[i][0]) <= 0):
                hasErrorData = True
                w = w + learn_rate * y[i][0] * data
                b = b + learn_rate * y[i][0]
                # print(w)
                # print(b)
                # print()
    return w,b

def printErrorDigit(data):
    for i in range(1, len(data) + 1):
        print(data[i-1], end = '')
        if(i % 32 == 0):
            print()

# 测试
def test_perceptron(testDataSet, testY, w, b):
    m = len(testDataSet)
    errorCount = 0
    for i in range(m):
        data = np.array(testDataSet[i])
        if((w.dot(data.T) + b) * testY[i][0] > 0):
            print("right: %d, calc: %d" % (testY[i][0], w.dot(data.T) + b))
        else:
            errorCount += 1
            print(printErrorDigit(testDataSet[i]))
            input("press any key to continue:")
    return float(errorCount) / m

def main():

    # 李航例题
    # 一致
    # trainDataSet = [[3, 3], [4, 3], [1, 1]]
    # trainLabel = [[1], [1], [-1]]
    # w, b = train_perceptron(trainDataSet, trainLabel)
    # print(w)
    # print(b)
    # print()

    # mnist
    trainDataSet, trainLabel = readData("data/trainingDigits")
    y = changeFrom(trainLabel, 0)    # 数字0作为正(y = 1), 其他数字作为负(y = -1)
    w, b = train_perceptron(trainDataSet, y)

    testDataSet, testLabel = readData("data/testDigits")
    testY = changeFrom(testLabel, 0)
    errorRate = test_perceptron(testDataSet, testY, w, b)
    print("total: %d, errorRate: %f" % (len(testDataSet), errorRate))

if __name__ == "__main__":
    main()
  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
MLP-MNIST是指使用多层感知(Multilayer Perceptron,简称MLP)模型对MNIST数据集进行分类的任务。MNIST数据集是一个常用于机器学习领域的手写数字识别数据集。它包含了60,000个训练样本和10,000个测试样本,每个样本都是一个28x28的灰度图像,图像上的数字标签表示该图像对应的数字。 MLP是一种经典的前馈神经网络模型,它由多个全连接层组成,每个层都包含了多个神经元。该模型可以通过学习来建立输入图像与对应数字之间的映射关系,从而实现对手写数字的分类任务。 要进行MLP-MNIST数据集的分类任务,可以按照以下步骤进行: 1. 读取数据集:首先,需要将MNIST数据集加载到程序中,可以使用适当的数据读取函数,如TensorFlow中的tf.keras.datasets模块中的load_data()函数。 2. 数据预处理:对于MLP模型,通常需要将图像数据进行平铺(flatten)操作,将二维的图像数据转换为一维的向量作为模型的输入。同时,还需要对图像数据进行归一化处理,将像素值缩放到0到1之间。 3. 初始化模型参数:根据需要选择合适的MLP模型结构,并对模型的参数进行初始化,如权重和偏置。 4. 定义激活函数:MLP模型中的每个神经元通常都会使用激活函数对其输出进行非线性变换,常见的激活函数包括ReLU、sigmoid和tanh等。 5. 防止过拟合:在MLP模型中,为了防止过拟合现象的发生,可以采用一些正则化技术,如权重衰减(weight decay)。 6. 训练模型:使用训练集对MLP模型进行训练,通过反向传播算法不断优化模型参数,使其能够更好地拟合训练数据。 7. 模型评估:使用测试集对训练好的模型进行评估,计算分类准确率等指标,以评估模型的性能。 综上所述,MLP-MNIST数据集是指使用多层感知模型对MNIST数据集进行分类任务的过程。通过适当的数据预处理、模型参数初始化、激活函数定义和防止过拟合等步骤,可以构建出一个能够对手写数字进行准确分类的MLP模型。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值