金麟岂是池中物:逻辑斯蒂回归算法及实现

金麟岂是池中物,一遇风云便化龙

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

 逻辑斯蒂回归是很常见的的算法模型,见下图:
在这里插入图片描述
 而对于Mnist数据集,我们是使用多项逻辑斯蒂回归,具体公式见下图:
在这里插入图片描述
 而在具体的使用过程中,我们通常是用softmax回归,如下图所示:
在这里插入图片描述
而损失函数如下,我们要做的就是尽可能让损失函数变小
在这里插入图片描述
 对于损失函数的优化,我们用梯度下降法,进行迭代。
在这里插入图片描述
算法实现如下:

import numpy as np
import time
import math
from math import log
def make_data_set(filename):
    file =open(filename,'r')
    data_label=[]
    data_set=[]
    for line in file:
        value=line.split(',')
        data_label.append(int(value[0]))
        tempt_list=np.asfarray(value[1:])/255.0
        #数据归一化很重要,不然以后的指数运算会溢出
        data_set.append(tempt_list)
    return data_set,data_label

def cal_py(data_set,w,b):
    scores=np.exp(np.dot(data_set,w.T)+b)
    s=np.sum(scores,axis=1,keepdims=True)
    #按行累加,保持行数不变
    if s.all()==False or scores.all()==False:
        return 0
    py=scores/s
    return py
    #py是概率矩阵,列数为类数,行数为训练数据的个数。
def fun_i(data_set,data_label):
    i_mat=np.zeros([len(data_label),10])
    for i in range(len(data_label)):
        i_mat[i][data_label[i]]=1
    return i_mat
    #返回指示矩阵
def train(data_set,data_label,epoch):
    data_set=np.array(data_set)
    w=np.zeros([10,data_set.shape[1]])
    b=np.zeros([1,10])
    learn=0.0001
    i_mat=fun_i(data_set,data_label)
    for i in range(epoch):
        py=cal_py(data_set,w,b)
        x=i_mat-py
        dw=-(1/len(data_label))*np.dot(x.T,data_set)
        db=-(1/len(data_label))*np.sum(x,axis=0)
        w-=dw*learn
        b-=db*learn
    return w,b
def predict(w,b,test_data_set_per):
        x=np.array(test_data_set_per)
        
        p_y=np.dot(x,w.T)
        p_y=p_y.reshape(1,10)
        p_y+=b
        #print(p_y)
        #print(np.argmax(p_y))
        return np.argmax(p_y)
def cal_accuracy(test_data_set,test_data_label,w,b):
    i=0
    cnt=0
    for line in test_data_set:
        re=predict(w,b,line)
        if test_data_label[i]==re:
            cnt+=1
        i+=1
    return cnt/len(test_data_label)
start=time.time()
train_data_set,train_data_label=make_data_set('D:\\bpnetwork\\mnist_train.csv')
test_data_set,test_data_label=make_data_set('D:\\bpnetwork\\mnist_test.csv')  
w,b=train(train_data_set,train_data_label,100)
print(cal_accuracy(test_data_set,test_data_label,w,b))
print(time.time()-start)


 一顿操作猛如虎,一看内存二百五…,该程序至少要准备1G的空闲内存,然鹅俺电脑一共才4G内存,拼拼凑凑才抠出1G的内存。以下是准确率和运行时间。
在这里插入图片描述

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值