BP神经网络:MNIST数据集(纯Python实现)

'''大数据集用mini_batch,带gui'''
from tkinter import *
import tkinter as tk
from tkinter.simpledialog import *
from tkinter.colorchooser import *
import random
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris, load_digits
from sklearn.metrics import mean_squared_error
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score
import seaborn as sns
import pickle
from sklearn.model_selection import train_test_split


def load_data(data_dir='./MINIST_data'):
    """Load the MNIST training and test sets from raw IDX files.

    The four files ``train-images.idx3-ubyte``, ``train-labels.idx1-ubyte``,
    ``t10k-images.idx3-ubyte`` and ``t10k-labels.idx1-ubyte`` are read from
    ``data_dir`` as raw bytes. The first 16 bytes of each image file and the
    first 8 bytes of each label file are skipped as header; the remaining
    bytes are the payload. The shape of every returned array is printed.

    Args:
        data_dir: Directory containing the four IDX files. Defaults to the
            location the script has always used.

    Returns:
        Tuple ``(train_data, train_labels, test_data, test_labels)`` of
        ``uint8`` arrays. Image arrays have shape ``(n_samples, 784)``;
        label arrays are one-dimensional.

    Raises:
        FileNotFoundError: If one of the files is missing.
        ValueError: If an image payload is not a multiple of 784 bytes.
    """
    import os  # local import so the block does not rely on a file-level one

    def _read_payload(filename, header_bytes):
        # Passing a path lets numpy read the file as raw bytes; the previous
        # code opened these binary files in text mode first.
        raw = np.fromfile(os.path.join(data_dir, filename), dtype=np.uint8)
        return raw[header_bytes:]

    # Training set. reshape(-1, 784) infers the sample count from the file
    # size instead of hard-coding 60000 / 10000.
    train_data = _read_payload('train-images.idx3-ubyte', 16).reshape((-1, 784))
    print(train_data.shape)  # (60000, 784) for the full data set

    train_labels = _read_payload('train-labels.idx1-ubyte', 8)
    print(train_labels.shape)  # (60000,) for the full data set

    # Test set.
    test_data = _read_payload('t10k-images.idx3-ubyte', 16).reshape((-1, 784))
    print(test_data.shape)  # (10000, 784) for the full data set

    test_labels = _read_payload('t10k-labels.idx1-ubyte', 8)
    print(test_labels.shape)  # (10000,) for the full data set
    return train_data, train_labels, test_data, test_labels


def sigmoid(x):
    """Return the logistic function of ``x``, evaluated element-wise.

    Uses the identity ``sigmoid(x) = (1 + tanh(x / 2)) / 2`` rather than
    ``1 / (1 + exp(-x))``, so large negative inputs do not overflow ``exp``.
    """
    half_x = 0.5 * x
    return 0.5 * (np.tanh(half_x) + 1)


class MLP:
    """Fully connected feed-forward network with sigmoid units.

    Training is plain gradient descent on the squared error, with the
    gradient summed (not averaged) over the samples of each batch.
    """

    def __init__(self, size):
        """Create randomly initialised weights and biases.

        Args:
            size: Sequence of layer widths, input layer first and output
                layer last, e.g. ``[784, 100, 10]``.

        Raises:
            ValueError: If fewer than two layer widths are given.
        """
        if len(size) < 2:
            raise ValueError('size must contain at least an input and an output layer')

        self.size = size

        # w[k] maps layer k to layer k + 1 and has shape (size[k + 1], size[k]);
        # b[k] is the matching column vector. Both are drawn from a normal
        # distribution whose standard deviation is (width of layer k + 1) ** -0.5.
        self.w = [np.random.normal(0.0, n_out ** -0.5, (n_out, n_in))
                  for n_in, n_out in zip(size[:-1], size[1:])]
        self.b = [np.random.normal(0.0, n_out ** -0.5, (n_out, 1)) for n_out in size[1:]]

    def predict_fun(self, X):
        """Return the predicted class index for every row of ``X``.

        Args:
            X: Array of shape ``(n_samples, size[0])``. A single 1-D sample
                is treated as a batch of one.

        Returns:
            1-D float array of length ``n_samples`` holding the index of the
            largest output unit for each sample (float dtype is kept for
            backward compatibility with earlier callers).
        """
        out = np.atleast_2d(X).T  # columns are samples from here on
        for w, b in zip(self.w, self.b):
            out = sigmoid(np.matmul(w, out) + b)
        # argmax over the unit axis replaces the former per-sample Python loop.
        return np.argmax(out, axis=0).astype(float)

    def bp(self, x, y, lr):
        """Run one forward/backward pass on a batch and update the parameters in place.

        Args:
            x: Inputs of shape ``(batch, size[0])``.
            y: Targets of shape ``(size[-1], batch)`` (one column per sample).
            lr: Learning rate.

        Returns:
            Mean squared error between the network output and ``y``, measured
            before the parameter update. The value is also printed.
        """
        # Forward pass: activations[k] is the output of layer k (input = layer 0).
        activations = [x.T]
        for w, b in zip(self.w, self.b):
            activations.append(sigmoid(np.dot(w, activations[-1]) + b))
        out = activations[-1]

        # Mean over every output unit and every sample.
        error = np.mean((out - y) ** 2)
        print(error)

        # Backward pass. delta is the derivative of 0.5 * sum((out - y) ** 2)
        # with respect to the pre-activation of the current layer.
        grad_w = [None] * len(self.w)
        grad_b = [None] * len(self.b)
        delta = out * (1 - out) * (out - y)
        for layer in range(len(self.w) - 1, -1, -1):
            grad_b[layer] = np.sum(delta, axis=1).reshape(-1, 1)
            grad_w[layer] = np.dot(delta, activations[layer].T)
            if layer > 0:
                # Propagate through the not-yet-updated weights of this layer.
                hidden = activations[layer]
                delta = np.dot(self.w[layer].T, delta) * hidden * (1 - hidden)

        # Apply all updates only after every gradient has been computed.
        for layer, (g_w, g_b) in enumerate(zip(grad_w, grad_b)):
            self.w[layer] -= lr * g_w
            self.b[layer] -= lr * g_b

        return error

    def main(self, dataset, target, lr):
        """One-hot encode integer class labels and perform one training step.

        Args:
            dataset: Inputs of shape ``(batch, size[0])``; a single 1-D sample
                is treated as a batch of one.
            target: Integer class labels, one per sample, each in
                ``range(size[-1])``.
            lr: Learning rate.

        Returns:
            The mean squared error reported by ``bp`` for this batch.
        """
        dataset = np.atleast_2d(dataset)
        labels = np.atleast_1d(target)
        # Row ``c`` of the identity matrix is the one-hot code of class ``c``;
        # transpose so that each column is one sample, as ``bp`` expects.
        one_hot = np.identity(self.size[-1])[labels].T
        return self.bp(dataset, one_hot, lr)


def save_model(model):
    """Serialise ``model`` to ``nmModel.pkl`` in the current working directory.

    The file is overwritten if it already exists. The highest pickle
    protocol available to the running interpreter is used.
    """
    with open('nmModel.pkl', 'wb') as sink:
        pickle.dump(model, sink, protocol=pickle.HIGHEST_PROTOCOL)


def load_model():
    """Load and return the object pickled in ``nmModel.pkl``.

    The file is looked up in the current working directory.

    Returns:
        Whatever object was stored in the file.

    Raises:
        FileNotFoundError: If ``nmModel.pkl`` does not exist.
    """
    # SECURITY: unpickling can execute arbitrary code. Only load a file that
    # this program (or another trusted source) wrote.
    # The ``with`` block closes the handle, which the previous version leaked.
    with open('nmModel.pkl', 'rb') as pkl:
        return pickle.load(pkl)


if __name__ == '__main__':
    # Build the model. Training resumes from the network pickled in
    # nmModel.pkl, so that file must already exist.
    model = load_model()

    # Train: each iteration draws a random mini-batch (without replacement
    # inside the batch) and performs one gradient step.
    X_train, Y_train, X_test, Y_test = load_data()
    epoch = 10000  # number of mini-batch steps, not full passes over the data
    batch_size = 64
    m, n = X_train.shape
    for i in range(epoch):
        random_index = random.sample(range(m), batch_size)
        model.main(X_train[random_index], Y_train[random_index], 0.001)
    # NOTE(review): the trained weights are never written back with
    # save_model(); confirm whether that is intended.

    # Evaluate on the held-out test set.
    predict = model.predict_fun(X_test)
    # scikit-learn metrics take (y_true, y_pred). The arguments were swapped,
    # which left the accuracy unchanged but transposed the confusion matrix.
    acc = accuracy_score(Y_test, predict)
    print('最终准确率为:', acc)
    c_m = confusion_matrix(Y_test, predict)  # rows: true class, columns: predicted
    sns.heatmap(c_m, annot=True)
    plt.show()

和以前的实现差别不大,只是改用了小批量(mini-batch)梯度下降。

 

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值