普通的最小二乘法
from sklearn import linear_model
import numpy as np
from matplotlib import pyplot as plt
# Draw three independent standard-normal feature vectors, 109 samples each
# (drawn in the order x0, x1, x2).
x0, x1, x2 = (np.random.normal(size=109) for _ in range(3))
# Noise-free target: the true coefficients for (x0, x1, x2) are (9, 2, 3).
y = 2 * x1 + 3 * x2 + 9 * x0
# Stack the features side by side into a (109, 3) design matrix.
X = np.column_stack((x0, x1, x2))
# Fit ordinary least squares on the first 50 samples ...
reg = linear_model.LinearRegression().fit(X[:50], y[:50])
# ... predict the remaining samples (the predictions are not used further) ...
y_pred = reg.predict(X[50:])
# ... and show the fitted coefficients for (x0, x1, x2).
print(reg.coef_)
岭回归
如果不用岭回归,效果如下
from sklearn import linear_model
import numpy as np
from matplotlib import pyplot as plt
# Two independent standard-normal features, 109 samples each (x0 drawn first).
x0, x1 = (np.random.normal(size=109) for _ in range(2))
# The third feature is an exact linear combination of the other two,
# which makes the columns of the design matrix perfectly collinear.
x2 = x0 + x1
# Noise-free target built with coefficients (6, 2, 3) for (x0, x1, x2).
y = 2 * x1 + 3 * x2 + 6 * x0
# Stack the features side by side into a (109, 3) design matrix.
X = np.column_stack((x0, x1, x2))
# Plain least squares on the first 80 samples of the collinear data;
# print the fitted coefficients for (x0, x1, x2).
reg = linear_model.LinearRegression().fit(X[:80], y[:80])
print(reg.coef_)
此时输出结果为[4.33333333 0.33333333 4.66666667],与生成数据时使用的系数(x0、x1、x2 分别为 6、2、3)不符:由于 x2 = x0 + x1 完全共线,满足 y = 9*x0 + 5*x1 的系数组合有无穷多个,最小二乘只返回其中范数最小的一组。
采用岭回归
from sklearn import linear_model
import numpy as np
from matplotlib import pyplot as plt
from sklearn import linear_model
import numpy as np
from matplotlib import pyplot as plt
# Two independent standard-normal features, 109 samples each (x0 drawn first).
x0, x1 = (np.random.normal(size=109) for _ in range(2))
# x2 is the exact sum of x0 and x1, so the three features are collinear.
x2 = x0 + x1
# Noise-free target built with coefficients (9, 2, 3) for (x0, x1, x2).
y = 2 * x1 + 3 * x2 + 9 * x0
# Design matrix of shape (109, 3), one column per feature.
X = np.column_stack((x0, x1, x2))
# Ridge regression with built-in cross-validation over three candidate
# regularisation strengths, fitted on the first 80 samples.
# (The original read `reg = reg = ...`; the duplicated assignment was a typo.)
reg = linear_model.RidgeCV(alphas=[0.0001, 1.0, 100.0])
reg.fit(X[:80], y[:80])
# Fitted coefficients for (x0, x1, x2), followed by the intercept.
print(reg.coef_)
print(reg.intercept_)
此时输出结果为[ 6.33332712 -0.66666281 5.66666431],同样与生成数据时使用的系数(x0、x1、x2 分别为 9、2、3)不符:岭回归只是让共线情况下的解变得稳定,并不能还原各个共线特征各自的真实系数。
Logistic回归
from sklearn.linear_model import LogisticRegression
from sklearn import metrics
import numpy as np
from matplotlib import pyplot as plt
# Training points as [x, y] pairs, listed class by class.
# Each class occupies one quadrant of the plane.
X0 = [
    # class 0: x < 0, y > 0
    [-1, 1],
    [-1, 2],
    [-2, 3],
    [-0.5, 6],
    [-5, 2],
    [-10, 1],
    [-19, 2],
    # class 1: x < 0, y < 0
    [-5, -1],
    [-6, -5],
    [-7, -8],
    [-2, -11],
    # class 2: x > 0, y > 0
    [2, 2],
    [3, 9],
    [4, 1],
    [5, 1],
    [6, 5],
    [4, 3],
    [2, 6],
    # class 3: x > 0, y < 0
    [3, -1],
    [6, -2],
    [6, -9],
    [1, -2],
    [1, -1],
    [2, -3],
    [3, -6],
    [7, -5],
]
# Split the pairs into separate float coordinate arrays (used for plotting).
# Transposing the (n, 2) array yields the x row and the y row, so the length
# is taken from the data instead of a hard-coded 26.
X, Y = np.asarray(X0, dtype=float).T
# Class label for each row of X0, in the same order: 7 points of class 0,
# 4 of class 1, 7 of class 2 and 8 of class 3.
y = [0] * 7 + [1] * 4 + [2] * 7 + [3] * 8
# Fit a logistic-regression classifier (default settings) on the labelled points.
reg = LogisticRegression().fit(X0, y)
# One probe point per quadrant; print the predicted class of each.
X_test = [[1, 1], [-1, -1], [1, -1], [-1, 1]]
print(reg.predict(X_test))
# Scatter-plot the training points, one series per class.
plt.figure(1)
ax = plt.axes()
# Fixed view window of [-10, 10] on both axes.
ax.set_xlim(-10, 10)
ax.set_ylim(-10, 10)
# Group the points by their label in `y` rather than by hard-coded slice
# boundaries, so the plot stays correct if the data set changes.
labels = np.asarray(y)
for cls in sorted(set(y)):
    mask = labels == cls
    plt.scatter(X[mask], Y[mask])
plt.show()