数据集放在百度网盘(见原文链接),就是西瓜书中的西瓜数据集3.0α。首先,用load_data(file)函数加载数据。
def load_data(file, encoding=None):
    """Read a space-separated text file into a list of rows.

    Args:
        file: Path of the text file to read.
        encoding: Text encoding handed to ``open``. The default ``None``
            keeps the platform default (the previous behaviour); pass
            e.g. ``"utf-8"`` when the file holds non-ASCII labels.

    Returns:
        A list with one entry per non-blank line. Each entry is the list
        of strings produced by splitting that line on single spaces.
    """
    rows = []
    with open(file, encoding=encoding) as f:
        for line in f:  # iterate lazily instead of readlines()
            line = line.rstrip('\n')  # each line is a str; drop the newline
            if not line.strip():
                # A blank line would become [''] and break callers that
                # index columns 2 and 3 of every row.
                continue
            rows.append(line.split(' '))  # fields are separated by one space
    return rows
将这个数据可视化下,大致是这个样子。
其中,蓝色、红色分别表示好瓜和坏瓜(代码中好瓜用默认的蓝色圆点'o',坏瓜用红色圆点'ro')。可以看出,只用一条直线并不容易把两类完全分开。
这部分代码如下:
# Scatter-plot the two numeric columns of the data set, one marker colour
# per class.
import pylab as pl

file = '../data/3_0a.txt'  # path to the data file
s = load_data(file)
print(type(s))

x = []   # good melons, column 2 (the original note calls it sugar content)
y = []   # good melons, column 3 (the original note calls it density)
x1 = []  # bad melons, column 2
y1 = []  # bad melons, column 3

# Row 0 is skipped, as in the original loops, which start at index 1.
# The class is taken from the label column (column 1 == '是' means a good
# melon) instead of a hard-coded row boundary, so the split no longer
# depends on how many good melons come first in the file.
for row in s[1:]:
    if len(row) < 4:
        continue  # incomplete row: nothing to plot
    if row[1] == '是':
        x.append(float(row[2]))
        y.append(float(row[3]))
    else:
        x1.append(float(row[2]))
        y1.append(float(row[3]))

pl.plot(x, y, 'o')      # good melons: default colour, circle markers
pl.plot(x1, y1, 'ro')   # bad melons: red circle markers
pl.show()
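我们接下来用对数几率回归模型,具体的公式可以看周志华的《机器学习》第三章的(3.27)式,其他资料中也有:$\ell(\beta)=\sum_{i=1}^{m}\left(-y_i\beta^{T}\hat{x}_i+\ln\left(1+e^{\beta^{T}\hat{x}_i}\right)\right)$。这是一个无约束的优化问题,可以直接用梯度下降法求解。对求导有疑问的,可以参考
《机器学习求导》一文(见原文链接)。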
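(3.27)式中,$y_i$是样本的结果,好瓜是1,坏瓜是0;$\hat{x}_i=(x_{i1};x_{i2};1)$是样本的属性,我们这里有两个属性,$\beta=(w_1;w_2;b)$。下面就是从前面读取的数据中把$x_i$、$y_i$读出来,然后把这些值代入梯度下降法中的导数项:$\frac{\partial \ell}{\partial \beta}=\sum_{i=1}^{m}\left(-y_i+\frac{e^{\beta^{T}\hat{x}_i}}{1+e^{\beta^{T}\hat{x}_i}}\right)\hat{x}_i$。w、b的初始值随便设置一个就行,然后迭代计算。代码如下: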
# Fit a logistic-regression model to the data set with batch gradient descent.
import numpy as np
import pylab as plt
import my_load_data as mld  # holds load_data(); the function may also be pasted in directly

file = '../data/3_0a.txt'
s = mld.load_data(file)

# Row 0 is skipped, as in the original loop, which starts at index 1.
# Every remaining complete row is used, rather than exactly 17 of them.
rows = [row for row in s[1:] if len(row) >= 4]

# Design matrix: the two attribute columns plus a constant 1 for the bias.
# Plain ndarrays are enough here: the @ operator does the matrix product.
x = np.array([[float(row[2]), float(row[3]), 1.0] for row in rows])
# Targets as a column vector: 1 for a good melon ('是'), 0 otherwise.
y = np.array([[1.0 if row[1] == '是' else 0.0] for row in rows])

start = np.array([[0.1], [10.0], [8.0]])  # initial parameters [w1; w2; b]
xishu = 0.01                               # learning rate
max_iter = 2 * 10**5                       # total number of iterations

for i in range(1, max_iter + 1):
    # 1 / (1 + exp(-z)) equals exp(z) / (1 + exp(z)) but does not turn
    # into inf/inf = nan when z is large.
    p = 1.0 / (1.0 + np.exp(-(x @ start)))
    # Summed gradient over all samples: sum_i (p_i - y_i) * x_i.
    grad = x.T @ (p - y)
    start = start - xishu * grad  # gradient-descent step
    if i % 10000 == 0:
        print('no%s' % i, 'start is %s' % start)
print(start)
迭代结果:
no10000 start is [[ 2.98758124]
[ 11.91671654]
[ -4.21286642]]
no20000 start is [[ 3.13439493]
[ 12.43023225]
[ -4.39732669]]
no30000 start is [[ 3.15464273]
[ 12.50721714]
[ -4.42401375]]
no40000 start is [[ 3.15776018]
[ 12.5190382 ]
[ -4.42811559]]
no50000 start is [[ 3.15824169]
[ 12.52086253]
[ -4.42874883]]
no60000 start is [[ 3.15831607]
[ 12.52114431]
[ -4.42884664]]
no70000 start is [[ 3.15832756]
[ 12.52118784]
[ -4.42886175]]
no80000 start is [[ 3.15832934]
[ 12.52119456]
[ -4.42886408]]
no90000 start is [[ 3.15832961]
[ 12.5211956 ]
[ -4.42886444]]
no100000 start is [[ 3.15832965]
[ 12.52119576]
[ -4.4288645 ]]
no110000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no120000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no130000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no140000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no150000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no160000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no170000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no180000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no190000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
no200000 start is [[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]
[[ 3.15832966]
[ 12.52119579]
[ -4.42886451]]