python实现总体最小二乘(TLS)

用python实现总体最小二乘

导入库,读取数据(数据网址为点击打开链接

import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
data=pd.read_table("/Users/cofreewy.txt")
x1=data['Traffic']
y1=data['CO']
数据归一化

from sklearn.preprocessing import scale
x1= scale(x1)
y1=scale(y1)
(1)拟合普通最小二乘 -- OLS

from scipy.optimize import leastsq
p0=[1,20]
def func(p,x):
    k,b=p
    return k*x+b
def error(p,x,y):
    return func(p,x)-y
Para=leastsq(error,p0,args=(x1,y1))
#读取结果
#coding=utf-8
k,b=Para[0]
print("k=",k,"b=",b)
print("cost:"+str(Para[1]))
print("拟合直线为:")
print("y="+str(round(k,4))+"x"+str(round(b,4))

#输出的OLS结果

 ('k=', 0.9626654614422958, 'b=', -4.1450221033301204e-10)

  cost:1
 拟合直线为:
  y=0.9627x+-0.0

画散点图及拟合直线

#画样本点
plt.scatter(x1,y1,color="green",label="label_num",linewidth=1)
#画拟合直线
y_ols=k*x1+b ##函数式
plt.plot(x1,y_ols,color="red",label="ols",linewidth=1) 
plt.legend(loc='lower right') #绘制图例
plt.show()


(2)定义总体最小二乘 -- TLS

# Total Least Squares:
def line_total_least_squares(x,y):
    n = len(x)
    
    x_m = np.sum(x)/n
    y_m = np.sum(y)/n
    
    # Calculate the x~ and y~ 
    x1 = x - x_m
    y1 = y - y_m
    
    # Create the matrix array
    X = np.vstack((x1, y1))
    X_t = np.transpose(X)
    
    # Finding A_T_A and it's Find smallest eigenvalue::
    prd = np.dot(X,X_t)
    W,V = np.linalg.eig(prd)
    small_eig_index = W.argmin()
    a,b = V[:,small_eig_index] 
    
    # Compute C:
    c = (-1*a*x_m) + (-1*b*y_m)
    
    return a,b,c
打印结果

# Total Least Squares:
print ('Training................................')
a1,b1,c1 = line_total_least_squares(x1,y1)
print ('Training Complete')
print 'a = ', a1
print 'b = ', b1
print 'c = ', c1
print("总体最小二乘法拟合直线为:")
print("y="+str(round(-a1,4))+"x"+str(round(-b1,4)))
#--TLS--输出拟合直线结果为
Training Complete
a =  -0.7071067811865474
b =  0.7071067811865477
c =  1.242989863124656e-16
总体最小二乘法拟合直线为:
y=0.7071x-0.7071
 绘出普通最小二乘和总体最小二乘拟合直线 
plt.figure(figsize=(6,6))
y_fitted = -1*(a1/b1)*x1 + (-1*(c1/b1))
plt.scatter(x1,y1,color="green",label="label_num",linewidth=1)
plt.plot(x1,y_fitted,color="red",label="LST",linewidth=1)
plt.plot(x1,y_ols,color="y",label="ols",linewidth=1) 
plt.legend(loc='lower right')
plt.show()

拟合图结果


评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值