Understanding Softmax Regression: Principles and Practice

    The previous post covered Logistic Regression, which handles binary classification; multi-class problems such as handwritten digit recognition call for the Softmax Regression algorithm, in which the decision boundary between any two classes is linear.

I have already written a post on handwritten digit recognition (https://blog.csdn.net/IMWTJ123/article/details/84072995). That task is to recognize the 10 digits {0, 1, ..., 9}, with images drawn from the MNIST dataset, so I won't repeat the details here.

1. The Softmax Regression Model

Suppose there are m training samples \{(X^{1}, y^{1}), (X^{2}, y^{2}), \dots, (X^{m}, y^{m})\}, with class labels y^{i} \in \{0, 1, \dots, 9\}. The hypothesis estimates, for each sample, the probability P(y = j \mid X) that it belongs to class j:

                                                P(y^{i} = j \mid X^{i}; \theta) = \frac{e^{\theta_{j}^{T} X^{i}}}{\sum_{l=1}^{k} e^{\theta_{l}^{T} X^{i}}}
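
As a quick numeric illustration (a minimal sketch with made-up values for \theta and X^{i}, not code from this post), the class probabilities are just exponentiated scores normalized to sum to one:

import numpy as np

theta = np.array([[0.2, 1.0, -0.4],   # column j holds theta_j; 3 classes, 2 features (hypothetical)
                  [-0.5, 0.3, 0.8]])
x = np.array([1.5, -0.7])             # one sample X^{i}

scores = x @ theta                    # theta_j^T X^{i} for each class j
probs = np.exp(scores) / np.exp(scores).sum()
print(probs, probs.sum())             # three probabilities summing to 1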

For the loss function J(\theta) we introduce the indicator function I(x), which is 0 when x is false and 1 when x is true.

The loss function and its gradient are then:

                                                J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} \sum_{j=1}^{k} I(y^{i} = j) \log \frac{e^{\theta_{j}^{T} X^{i}}}{\sum_{l=1}^{k} e^{\theta_{l}^{T} X^{i}}}

                                                \nabla_{\theta_{j}} J(\theta) = -\frac{1}{m} \sum_{i=1}^{m} X^{i} \left[ I(y^{i} = j) - P(y^{i} = j \mid X^{i}; \theta) \right]

Gradient descent then updates each parameter vector as:

                                                        \theta_{j} = \theta_{j} - \alpha \nabla_{\theta_{j}} J(\theta)
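
To make the update concrete, here is a minimal sketch (with made-up data, not the code from this post) that evaluates the loss and gradient on a toy batch and takes one descent step:

import numpy as np

X = np.array([[1.0, 0.5, 2.0],             # two samples, bias term prepended (hypothetical)
              [1.0, -1.2, 0.3]])
y = np.array([0, 2])                        # class labels, k = 3 classes
theta = np.zeros((3, 3))                    # one column of parameters per class

logits = X @ theta
probs = np.exp(logits - logits.max(axis=1, keepdims=True))
probs /= probs.sum(axis=1, keepdims=True)   # softmax probabilities P(y=j|x)

onehot = np.eye(3)[y]                       # the indicator I(y^{i} = j)
loss = -np.mean(np.log(probs[np.arange(len(y)), y]))
grad = -X.T @ (onehot - probs) / len(y)     # the gradient formula above
theta -= 0.4 * grad                         # one gradient-descent step
print(loss, theta)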

2. The Relationship between Logistic Regression and Softmax Regression

Logistic Regression is a special case of Softmax Regression. For k = 2, the softmax probabilities collapse to the sigmoid:

                                                P(y = 1 \mid X) = \frac{e^{\theta_{1}^{T} X}}{e^{\theta_{1}^{T} X} + e^{\theta_{2}^{T} X}} = \frac{1}{1 + e^{-(\theta_{1} - \theta_{2})^{T} X}}, \qquad P(y = 2 \mid X) = 1 - P(y = 1 \mid X)

The softmax parameterization is redundant (subtracting a common vector from every \theta_{j} leaves the probabilities unchanged), so setting \theta = \theta_{1} - \theta_{2} recovers exactly the sigmoid hypothesis of Logistic Regression. Hence, for k = 2, the Logistic Regression and Softmax Regression algorithms are equivalent.
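
A quick numeric check of this equivalence (a minimal sketch with made-up parameter values, not code from this post):

import numpy as np

theta = np.array([[0.7, -0.2],   # columns are theta_1 and theta_2 (hypothetical values)
                  [0.1, 0.4]])
x = np.array([1.0, 2.5])

scores = x @ theta
p_softmax = np.exp(scores[0]) / np.exp(scores).sum()
p_sigmoid = 1.0 / (1.0 + np.exp(-x @ (theta[:, 0] - theta[:, 1])))
print(p_softmax, p_sigmoid)      # identical up to floating-point rounding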

3. Softmax Regression in Practice

# coding:UTF-8

# Softmax_train.py
import numpy as np

# Gradient update function
def gdAscent(feature_data, label_data, k, maxCycle, alpha):
    '''Train Softmax Regression by gradient ascent on the log-likelihood.
    input:  feature_data(mat): features
            label_data(mat): labels
            k(int): number of classes
            maxCycle(int): maximum number of iterations
            alpha(float): learning rate
    output: weights(mat): the trained weights
    '''
    m, n = np.shape(feature_data)
    weights = np.mat(np.ones((n, k)))  # initialize the weights
    i = 0
    while i <= maxCycle:
        err = np.exp(feature_data * weights)  # unnormalized scores e^(theta_j^T x)
        if i % 1000 == 0:
            print("\t-----iter: ", i, ", cost: ", cost(err, label_data))
        rowsum = -err.sum(axis=1)          # negated normalizer, so that
        rowsum = rowsum.repeat(k, axis=1)  # err / rowsum = -P(y=j|x)
        err = err / rowsum
        for x in range(m):
            err[x, label_data[x, 0]] += 1  # err = I(y=j) - P(y=j|x): the ascent direction
        weights = weights + (alpha / m) * feature_data.T * err
        i += 1
    return weights

def cost(err, label_data):
    '''Compute the loss (average cross-entropy over the samples).'''
    m = np.shape(err)[0]
    sum_cost = 0.0
    for i in range(m):
        if err[i, label_data[i, 0]] / np.sum(err[i, :]) > 0:
            sum_cost -= np.log(err[i, label_data[i, 0]] / np.sum(err[i, :]))
        else:
            sum_cost -= 0  # probability underflowed to 0; skip this sample
    return sum_cost / m

# Function for loading the training data

def load_data(inputfile):
    f = open(inputfile)  # open the file
    feature_data = []
    label_data = []
    for line in f.readlines():
        feature_tmp = []
        feature_tmp.append(1)  # bias term
        lines = line.strip().split("\t")
        for i in range(len(lines) - 1):
            feature_tmp.append(float(lines[i]))
        label_data.append(int(lines[-1]))
        feature_data.append(feature_tmp)
    f.close()  # close the file
    return np.mat(feature_data), np.mat(label_data).T, len(set(label_data))
        
def save_model(file_name, weights):
    '''Save the trained weights to a tab-separated text file.'''
    f_w = open(file_name, "w")
    m, n = np.shape(weights)
    for i in range(m):
        w_tmp = []
        for j in range(n):
            w_tmp.append(str(weights[i, j]))
        f_w.write("\t".join(w_tmp) + "\n")
    f_w.close()

# Main routine: train the model

if __name__ == "__main__":

    inputfile = "D:/anaconda4.3/spyder_work/test.txt"
    print("---------- 1.load data ------------")
    feature, label, k = load_data(inputfile)
    print("---------- 2.training ------------")
    weights = gdAscent(feature, label, k, 10000, 0.4)
    print("---------- 3.save model ------------")
    save_model("weights", weights)

Contents of the saved weights file after training:
-3.5726865289846708	26.124877704630585	-33.6100448734133	15.057853697767758
2.3994919672871533	0.40175773085236655	-0.4381233223631104	1.6368736242235842
2.440495875207079	-3.91082999515059	5.437464787120998	0.0328693328225254
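
As a sanity check (a sketch added here, not part of the scripts), a single point can be classified by hand with these weights: prepend the bias 1, multiply by the weight matrix, and take the argmax. For the first sample of test.txt this reproduces its label, 2:

import numpy as np

# rows: bias, feature 1, feature 2; columns: the k = 4 classes
W = np.array([[-3.5726865289846708, 26.124877704630585, -33.6100448734133, 15.057853697767758],
              [2.3994919672871533, 0.40175773085236655, -0.4381233223631104, 1.6368736242235842],
              [2.440495875207079, -3.91082999515059, 5.437464787120998, 0.0328693328225254]])

x = np.array([1.0, -0.017612, 14.053064])   # first sample of test.txt, bias prepended
print((x @ W).argmax())                      # prints 2, matching the sample's label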
# -*- coding: utf-8 -*-
"""
Created on Thu Feb 21 19:35:44 2019

@author: 2018061801
"""
# Softmax_test.py
import numpy as np
import random as rd

def load_weights(weights_path):
    '''Load the trained Softmax model
    input:  weights_path(string): where the weights file is stored
    output: weights(mat): the weights loaded into a matrix
            m(int): number of rows of the weight matrix
            n(int): number of columns of the weight matrix
    '''
    f = open(weights_path)
    w = []
    for line in f.readlines():
        w_tmp = []
        lines = line.strip().split("\t")
        for x in lines:
            w_tmp.append(float(x))
        w.append(w_tmp)
    f.close()
    weights = np.mat(w)
    m, n = np.shape(weights)
    return weights, m, n

def load_data(num, m):
    '''Generate random test data
    input:  num(int): number of test samples to generate
            m(int): dimensionality of a sample (including the bias column)
    output: testDataSet(mat): the generated test samples
    '''
    testDataSet = np.mat(np.ones((num, m)))
    for i in range(num):
        testDataSet[i, 1] = rd.random() * 6 - 3  # random number in [-3, 3)
        testDataSet[i, 2] = rd.random() * 15     # random number in [0, 15)
    return testDataSet

def predict(test_data, weights):
    '''Predict the classes of the test data with the trained Softmax model
    input:  test_data(mat): features of the test data
            weights(mat): weights of the model
    output: h.argmax(axis=1): the predicted class of each sample
    '''
    h = test_data * weights
    # exp() is monotonic, so taking the argmax of the raw scores gives the
    # same class as the argmax of the softmax probabilities
    return h.argmax(axis=1)

def save_result(file_name, result):
    '''Save the final predictions
    input:  file_name(string): name of the file holding the final results
            result(mat): the final predictions
    '''
    f_result = open(file_name, "w")
    m = np.shape(result)[0]
    for i in range(m):
        f_result.write(str(result[i, 0]) + "\n")
    f_result.close()

if __name__ == "__main__":
    # 1、导入Softmax模型
    print ("---------- 1.load model ------------")
    w, m , n = load_weights("weights")
    # 2、导入测试数据
    print ("---------- 2.load data ------------")
    test_data = load_data(4000, m)
    # 3、利用训练好的Softmax模型对测试数据进行预测
    print ("---------- 3.get Prediction ------------")
    result = predict(test_data, w)
    # 4、保存最终的预测结果
    print ("---------- 4.save prediction ------------")
    save_result("result", result)
test.txt (each line: two tab-separated features followed by the class label):
-0.017612	14.053064	2
-1.395634	4.662541	3
-0.752157	6.53862	3
-1.322371	7.152853	3
0.423363	11.054677	2
0.406704	7.067335	3
0.667394	12.741452	2
-2.46015	6.866805	3
0.569411	9.548755	0
-0.026632	10.427743	2
0.850433	6.920334	3
1.347183	13.1755	2
1.176813	3.16702	3
-1.781871	9.097953	2
-0.566606	5.749003	3
0.931635	1.589505	1
-0.024205	6.151823	3
-0.036453	2.690988	1
-0.196949	0.444165	1
1.014459	5.754399	3
1.985298	3.230619	3
-1.693453	-0.55754	1
-0.576525	11.778922	2
-0.346811	-1.67873	1
-2.124484	2.672471	1
1.217916	9.597015	0
-0.733928	9.098687	0
-3.642001	-1.618087	1
0.315985	3.523953	3
1.416614	9.619232	0
-0.386323	3.989286	3
0.556921	8.294984	0
1.224863	11.58736	2
-1.347803	-2.406051	1
1.196604	4.951851	3
0.275221	9.543647	0
0.470575	9.332488	0
-1.889567	9.542662	2
-1.527893	12.150579	2
-1.185247	11.309318	2
-0.445678	3.297303	3
1.042222	6.105155	3
-0.618787	10.320986	2
1.152083	0.548467	1
0.828534	2.676045	3
-1.237728	10.549033	2
-0.683565	-2.166125	1
0.229456	5.921938	3
-0.959885	11.555336	2
0.492911	10.993324	2
0.184992	8.721488	0
-0.355715	10.325976	2
-0.397822	8.058397	0
0.824839	13.730343	2
1.507278	5.027866	3
0.099671	6.835839	3
-0.344008	10.717485	2
1.785928	7.718645	0
-0.918801	11.560217	2
-0.364009	4.7473	3
-0.841722	4.119083	3
0.490426	1.960539	1
-0.007194	9.075792	0
0.356107	12.447863	2
0.342578	12.281162	2
-0.810823	-1.466018	1
2.530777	6.476801	3
1.296683	11.607559	2
0.475487	12.040035	2
-0.783277	11.009725	2
0.074798	11.02365	2
-1.337472	0.468339	1
-0.102781	13.763651	2
-0.147324	2.874846	3
0.518389	9.887035	0
1.015399	7.571882	0
-1.658086	-0.027255	1
1.319944	2.171228	1
2.056216	5.019981	3
-0.851633	4.375691	3
-1.510047	6.061992	3
-1.076637	-3.181888	1
1.821096	10.28399	0
3.01015	8.401766	0
-1.099458	1.688274	1
-0.834872	-1.733869	1
-0.846637	3.849075	3
1.400102	12.628781	2
1.752842	5.468166	3
0.078557	0.059736	1
0.089392	-0.7153	1
1.825662	12.693808	2
0.197445	9.744638	0
0.126117	0.922311	1
-0.679797	1.22053	1
0.677983	2.556666	1
0.761349	10.693862	0
-2.168791	0.143632	1
1.38861	9.341997	0
0.317029	14.739025	2
-2.65887965178	0.658328066452	1
-2.30615885683	11.5036718065	2
-2.83005963556	7.30810428189	3
-2.30319006285	3.18958964564	1
-2.31349250532	4.41749905123	3
-2.71157223048	0.21599278192	1
-2.99935111344	14.5766538514	2
-2.50329272687	12.7274016382	2
-2.14191210185	9.75999136268	2
-2.21409612618	9.25234159289	2
-2.0503599261	1.87312594247	1
-2.99747377006	2.82404034943	1
-2.39019233623	1.88778487771	1
-2.00981101171	13.0015287952	2
-2.06105014551	7.26924117028	3
-2.94028883652	10.8418044558	2
-2.56811396636	1.31240093493	1
-2.89942462914	7.47932555859	3
-2.83349151782	0.292728283929	1
-2.16467022383	4.62184237142	3
2.02604290795	6.68200376515	3
2.3755881562	9.3838379637	0
2.48299208843	9.75753701005	0
2.65108044441	9.39059526201	0
2.49422603944	11.856131521	0
2.47215954581	4.83431641068	3
2.26731525725	5.64891602081	3
2.33628075296	10.4603294628	0
2.4548064459	9.90879879651	0
2.13147505967	8.99561368732	0
2.86925733903	4.26531919929	3
2.05715970133	4.97240425903	3
2.14839753847	8.91032469409	0
2.17630437606	5.76122354509	3
2.86205491781	11.630342945	0

References:

Zhao Zhiyong, 《Python 机器学习算法》 (Python Machine Learning Algorithms)
