Logistic回归解析(python代码)


机器学习之Logistic回归解析及实例应用
22/100
viviliving


<没有设置>
from numpy import *
import matplotlib.pyplot as plt
from matplotlib.font_manager import FontProperties
def loadDataSet():
    dataMat = []
    labelMat = []
    fr = open('testSet.txt')
    #读取文本文件
    for line in fr.readlines():
        #读取全部行内容并且逐行遍历
        lineArr = line.strip().split()
        #对单行进行处理,伤处空格换行符并且切分
        dataMat.append([1.0,float(lineArr[0]),float(lineArr[1])])
        #将文本中的数据的前两个x1,x2和1.0作为x0一起放入列表中
        labelMat.append(int(lineArr[2]))
        #将标签放入标签列表
    return dataMat,labelMat
    
def sigmoid(inX):#sigmoid函数计算
    return 1.0/(1+exp(-inX))
    # return .5 * (1 + tanh(.5 * inX))

def stocGradAscent1(dataMatrix,classLabels,numIter=150):#随机梯度上升算法
    #定义了迭代次数,可更改
    m,n = shape(dataMatrix)
    #获得数据集的行列数
    weights = ones(n)
    #初始化回归系数为1
    weights_arry = array([])
    for j in range(numIter):
        dataIndex = list(range(m))
        #创建数据集索引列表
        for i in range(m):
            alpha = 4/(1.0+j+i)+0.01
            #降低alpha的大小,每次减小1/(j+i)
            randIndex = int(random.uniform(0,len(dataIndex)))
            #产生随机数,即随机的样本
            h = sigmoid(sum(dataMatrix[randIndex]*weights))
            #计算函数值
            error = classLabels[randIndex] - h
            weights = weights + alpha*error*dataMatrix[randIndex]
            #公式计算最佳回归系数
            weights_arry = append(weights_arry,weights,axis=0)
            del(dataIndex[randIndex])
            #删除适用过的数据的索引
    weights_arry = weights_arry.reshape(numIter*m,n)
    return weights,weights_arry
    
def plotWeights(weights_array):
    fig = plt.figure()
    font = FontProperties(fname=r"c:\windows\fonts\simsun.ttc",size=14)
    x = arange(0,len(weights_array),1)
    ax1 = fig.add_subplot(3,1,1)
    ax1.plot(x,weights_array[:,0])
    plt.ylabel('W0')
    ax2 = fig.add_subplot(3,1,2)
    ax2.plot(x,weights_array[:,1])
    plt.ylabel('W1')
    ax3 = fig.add_subplot(3,1,3)
    ax3.plot(x,weights_array[:,2])
    plt.xlabel('迭代次数',fontproperties=font)
    plt.ylabel('W3')
    plt.show()
    
if __name__=='__main__':
    dataArr,labelMat = loadDataSet()
    weights,weights_array = stocGradAscent1(array(dataArr),labelMat)
    plotWeights(weights_array)

 

 

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值