1、加载数据集
import numpy as np
import matplotlib.pyplot as plt
class moon_data_class(object):
    """Generator for the two-class "double moon" toy dataset.

    The first class is sampled uniformly (by rejection sampling) from a
    half-ring in the upper half-plane, centred on the origin, with
    centre-line radius ``r`` and width ``w``.  The second class is the same
    set of points mirrored across the x-axis, shifted right by ``r`` and
    down by ``d``.

    Args:
        N: number of points per moon (``dbmoon`` returns ``2 * N`` rows).
        d: vertical offset applied to the mirrored (lower) moon.
        r: radius of the ring's centre line.
        w: ring width.
    """

    def __init__(self, N, d, r, w):
        self.N = N
        self.w = w
        self.d = d
        self.r = r

    def sgn(self, x):
        """Return 1 if ``x > 0``, otherwise -1 (zero maps to -1)."""
        return 1 if x > 0 else -1

    def sig(self, x):
        """Return the logistic sigmoid ``1 / (1 + exp(-x))`` of ``x``."""
        return 1.0 / (1 + np.exp(-x))

    def dbmoon(self):
        """Sample the double-moon dataset.

        Returns:
            ndarray of shape ``(2 * N, 2)``. Rows ``[0, N)`` are the upper
            moon. Rows ``[N, 2N)`` are the lower moon, where row ``N + i``
            is row ``i`` mirrored and shifted (the two moons are not sampled
            independently).
        """
        n = self.N
        r = self.r
        w2 = self.w / 2
        d = self.d
        # Oversample so that one or two rejection rounds usually suffice.
        n_candidates = 10 * n

        chunks = []
        collected = 0
        # Rejection sampling: draw points uniformly from the rectangle
        # [-(r+w2), r+w2) x [0, r+w2) and keep those inside the ring
        # r-w2 < |p| < r+w2.  The loop always runs at least once, so N == 0
        # still yields a well-formed (0, 2) result.
        while True:
            tmp_x = 2 * (r + w2) * (np.random.random([n_candidates, 1]) - 0.5)
            tmp_y = (r + w2) * np.random.random([n_candidates, 1])
            dist = np.sqrt(tmp_x * tmp_x + tmp_y * tmp_y)
            keep = np.logical_and(dist > (r - w2), dist < (r + w2)).ravel()
            chunk = np.concatenate((tmp_x, tmp_y), axis=1)[keep]
            chunks.append(chunk)
            collected += chunk.shape[0]
            # Use the instance's N, not any module-level variable.
            if collected >= n:
                break

        upper = np.concatenate(chunks, axis=0)[0:n, :]

        # Lower moon: mirror across the x-axis, shift right by r, down by d.
        lower = np.empty([n, 2])
        lower[:, 0] = upper[:, 0] + r
        lower[:, 1] = -upper[:, 1] - d
        return np.concatenate((upper, lower), axis=0)
# Dataset parameters.
N = 100      # points per moon
d = 1        # vertical offset of the lower moon (rebound to the labels below)
r = 10       # ring centre-line radius
width = 6    # ring width
data_source = moon_data_class(N, d, r, width)
data = data_source.dbmoon()  # (2N, 2): upper-moon rows first, then lower moon

# NOTE(review): a, num_MSE, num_step, x0 and w are not used by the
# least-squares fit that follows; kept in case other code relies on them.
# x0 looks like it was meant as a bias column -- confirm before using.
a = 0.001
num_MSE = []
num_step = []
x0 = [1] * (2 * N)

# Feature matrix: one (x1, x2) row per sample (a fresh copy of data).
x = np.array(data[0:2 * N])
w = np.array([0, 0])

# Targets: +1 for the upper moon, -1 for the lower moon.  Sized from N so
# that x and d stay aligned if N changes.
d_pre = [1] * N
d_pos = [-1] * N
d = d_pre + d_pos
2、利用最小二乘法进行计算
公式：

$$B = (X^{T}X)^{-1}X^{T}Y$$

其中 $X$ 为样本坐标矩阵（每行一个点 $(x_1, x_2)$），$Y$ 为类别标签（上半月为 $+1$，下半月为 $-1$）。
# Least-squares fit: B minimises ||x @ B - d||^2, which is the normal-equation
# solution B = (X^T X)^{-1} X^T Y.  np.linalg.lstsq solves it via SVD rather
# than forming an explicit inverse, which is more stable when X^T X is
# ill-conditioned and still returns a (minimum-norm) answer if it is singular.
XT = x.T
B = np.linalg.lstsq(x, np.asarray(d, dtype=float), rcond=None)[0]
3、绘制运算结果
# Plot the fitted decision boundary B[0]*x1 + B[1]*x2 = 0 together with the
# two classes.  No bias term was fitted, so the boundary passes through the
# origin.  New names are used so the feature matrix `x` is not overwritten.
x_line = np.linspace(data[:, 0].min() - 2, data[:, 0].max() + 2, 100)
if B[1] != 0:
    y_line = -x_line * B[0] / B[1]
    plt.plot(x_line, y_line, 'g--')
else:
    # Degenerate case: the boundary is the vertical line x1 = 0.
    plt.axvline(0, color='g', linestyle='--')
plt.plot(data[0:N, 0], data[0:N, 1], 'r*',
         data[N:2 * N, 0], data[N:2 * N, 1], 'b*')
plt.show()
4、运行结果