# Plotting with pandas: LASSO coefficient paths
import itertools
from math import exp
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
def lasso_regression(X_array, y, lambd, threshold=0.1, max_iter=None):
    """Estimate LASSO coefficients by cyclic coordinate descent.

    A column of ones is prepended to ``X_array`` so the first returned
    coefficient plays the role of an intercept.  That intercept is
    soft-thresholded exactly like every other coefficient (it is *not*
    left unpenalised).

    Each coordinate update is the closed-form minimiser of
    ``||y - X w||^2 + lambd * ||w||_1`` with respect to ``w[k]`` while the
    remaining coefficients are held fixed.

    Args:
        X_array: 2-D array of features, one row per sample.
        y: Targets, one per sample.  A column vector is flattened.
        lambd: Weight of the L1 penalty.
        threshold: Stop once the absolute change of the residual sum of
            squares between two consecutive sweeps drops below this value.
        max_iter: Optional cap on the number of full sweeps.  ``None`` (the
            default) sweeps until ``threshold`` is met.  At least one sweep
            is always performed.

    Returns:
        1-D ``numpy.ndarray`` of length ``n_features + 1``; index 0 is the
        coefficient of the prepended column of ones.
    """
    X = np.column_stack((np.ones(X_array.shape[0]), X_array))
    # Flatten so an (m, 1) target does not broadcast the residual to (m, m).
    y = np.asarray(y, dtype=float).ravel()

    n = X.shape[1]
    w = np.zeros(n)

    # Squared column norms never change between sweeps, so compute them once.
    col_sq_norm = (X * X).sum(axis=0)
    half_penalty = lambd / 2.0

    # Running residual y - X w; with w == 0 it is simply y.
    resid = y.copy()
    prev_rss = np.dot(resid, resid)

    for sweep in itertools.count(1):
        for k in range(n):
            col = X[:, k]
            # Correlation of feature k with the residual that excludes
            # feature k's own current contribution.
            p_k = np.dot(col, resid) + col_sq_norm[k] * w[k]

            # Soft-thresholding step.
            if p_k < -half_penalty:
                new_w_k = (p_k + half_penalty) / col_sq_norm[k]
            elif p_k > half_penalty:
                new_w_k = (p_k - half_penalty) / col_sq_norm[k]
            else:
                new_w_k = 0.0

            resid -= col * (new_w_k - w[k])
            w[k] = new_w_k

        # Recompute the residual from scratch once per sweep so rounding
        # error from the incremental updates cannot accumulate.
        resid = y - np.dot(X, w)
        rss = np.dot(resid, resid)
        delta = abs(rss - prev_rss)
        prev_rss = rss

        if delta < threshold:
            break
        if max_iter is not None and sweep >= max_iter:
            break
    return w
def lasso_trace(X, Y, lambda_1=30):
    """Fit ``lasso_regression`` over a geometric grid of penalties.

    The penalties are ``exp(j - 15)`` for ``j = 0 .. lambda_1 - 1``; every
    fit uses ``threshold=0.1``.

    Returns:
        A list of ``(coefficient_vector, penalty)`` tuples, one per grid
        point, in increasing-penalty order.
    """
    penalties = [exp(step - 15) for step in range(lambda_1)]
    return [
        (lasso_regression(X, Y, penalty, threshold=0.1), penalty)
        for penalty in penalties
    ]
# Load the raw table (no header row) as a plain ndarray.
X_data = pd.read_csv(r'Desktop\data_new.csv', header=None, delimiter=',').values

# Fit the LASSO path on the first 200 rows: the first 8 columns are the
# features and the last column is the target.
data_1 = lasso_trace(X_data[:200, :8], X_data[:200, -1], lambda_1=30)

# Split the (coefficients, penalty) pairs into two parallel sequences.
Data_list = [coefficients for coefficients, _ in data_1]
Datalist_ind = [penalty for _, penalty in data_1]

Datalist_ind_ar = np.array(Datalist_ind)
Data_list_ar = np.array(Data_list)
Data_list_ar
Out[4]:
array([[ 2.79348445, 0.15492015, 10.85218844, 2.45052834,
-5.5520772 , 1.78585009, -1.05854674, -6.71211553,
9.08889764],
[ 2.79348427, 0.15492015, 10.85218773, 2.45052801,
-5.55207116, 1.78584959, -1.05854605, -6.71211511, 9.0888975 ],
[ 2.79348377, 0.15492017, 10.85218578, 2.45052713,
-5.55205472, 1.78584823, -1.05854418, -6.71211397,
9.08889713],
[ 2.79348242, 0.15492021, 10.85218047, 2.45052474,
-5.55201004, 1.78584453, -1.0585391 , -6.71211086, 9.0888961 ],
[ 2.79347874, 0.15492032, 10.85216605, 2.45051825,
-5.5518886 , 1.78583449, -1.05852529, -6.7121024 ,
9.08889332],
[ 2.79346874, 0.15492062, 10.85212685, 2.45050059,
-5.55155847, 1.78580718, -1.05848774, -6.71207943,
9.08888576],
[ 2.79344156, 0.15492144, 10.8520203 , 2.45045258,
-5.5506611 , 1.78573296, -1.05838569, -6.71201697, 9.0888652 ],
[ 2.79336767, 0.15492366, 10.85173066, 2.45032209,
-5.54822178, 1.78553121, -1.05810827, -6.71184718,
9.08880931],
[ 2.79316683, 0.15492969, 10.85094334, 2.44996739,
-5.54159103, 1.78498277, -1.05735418, -6.71138567, 9.0886574 ],
[ 2.79262087, 0.1549461 , 10.84880317, 2.44900321,
-5.52356679, 1.78349198, -1.05530434, -6.71013114,
9.08824447],
[ 2.79113682, 0.15499069, 10.8429856 , 2.44638228,
-5.47457181, 1.77943958, -1.04973231, -6.70672099, 9.087122 ],
[ 2.78710274, 0.1551119 , 10.8271718 , 2.43925788,
-5.34138966, 1.76842402, -1.03458595, -6.69745123,
9.08407082],
[ 2.77606227, 0.15544489, 10.78335654, 2.41959679,
-4.97418339, 1.73859301, -0.9933354 , -6.6729087 ,
9.07541591],
[ 2.75855702, 0.1563386 , 10.63624462, 2.34968765,
-3.96024738, 1.66625799, -0.88454292, -6.59694821,
9.04751813],
[ 2.72584138, 0.1587734 , 10.20535513, 2.14333551,
-1.20391696, 1.48686962, -0.60195674, -6.38253862, 8.9589362 ],
[ 2.98511488, 0.16074567, 9.29830451, 1.74278896,
0. , 1.3390553 , 0. , -5.16552212,
8.63704446],
[ 3.49126616, 0.17581216, 7.98176878, 1.13999794,
0. , 0.86268062, 0. , -0.36535168,
8.44577302],
[ 4.22098008, 0.16342411, 6.5460154 , 0. ,
0. , 2.26203814, 0. , 0. ,
5.63922567],
[ 4.89821546, 0.12996178, 4.3885469 , 0. ,
0. , 4.60855106, 0. , 0. , 0. ],
[ 6.26532372, 0.01755422, 0. , 0. ,
0. , 5.57462636, 0. , 0. , 0. ],
[ 6.32781147, 0. , 0. , 0. ,
0. , 5.23442736, 0. , 0. , 0. ],
[ 6.08620272, 0. , 0. , 0. ,
0. , 4.69472192, 0. , 0. , 0. ],
[ 5.64711907, 0. , 0. , 0. ,
0. , 3.02514678, 0. , 0. , 0. ],
[ 3.28260503, 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ],
[ 0. , 0. , 0. , 0. ,
0. , 0. , 0. , 0. , 0. ]])
# Penalty values in column 0, followed by the nine fitted coefficients.
comnine_ar = np.column_stack((Datalist_ind_ar, Data_list_ar))

# Coefficient table only (the penalty column is sliced off again), with one
# row per penalty and the natural log of the penalty as the row label.
coefficient_names = ['a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9']
log_penalties = list(np.log(Datalist_ind_ar))
comnine_ar_frame = pd.DataFrame(
    comnine_ar[:, 1:],
    index=log_penalties,
    columns=coefficient_names,
)
comnine_ar_frame
Out[6]:
a1 a2 a3 a4 a5 a6 a7 \
-15.0 2.793484 0.154920 10.852188 2.450528 -5.552077 1.785850 -1.058547
-14.0 2.793484 0.154920 10.852188 2.450528 -5.552071 1.785850 -1.058546
-13.0 2.793484 0.154920 10.852186 2.450527 -5.552055 1.785848 -1.058544
-12.0 2.793482 0.154920 10.852180 2.450525 -5.552010 1.785845 -1.058539
-11.0 2.793479 0.154920 10.852166 2.450518 -5.551889 1.785834 -1.058525
-10.0 2.793469 0.154921 10.852127 2.450501 -5.551558 1.785807 -1.058488
-9.0 2.793442 0.154921 10.852020 2.450453 -5.550661 1.785733 -1.058386
-8.0 2.793368 0.154924 10.851731 2.450322 -5.548222 1.785531 -1.058108
-7.0 2.793167 0.154930 10.850943 2.449967 -5.541591 1.784983 -1.057354
-6.0 2.792621 0.154946 10.848803 2.449003 -5.523567 1.783492 -1.055304
-5.0 2.791137 0.154991 10.842986 2.446382 -5.474572 1.779440 -1.049732
-4.0 2.787103 0.155112 10.827172 2.439258 -5.341390 1.768424 -1.034586
-3.0 2.776062 0.155445 10.783357 2.419597 -4.974183 1.738593 -0.993335
-2.0 2.758557 0.156339 10.636245 2.349688 -3.960247 1.666258 -0.884543
-1.0 2.725841 0.158773 10.205355 2.143336 -1.203917 1.486870 -0.601957
0.0 2.985115 0.160746 9.298305 1.742789 0.000000 1.339055 0.000000
1.0 3.491266 0.175812 7.981769 1.139998 0.000000 0.862681 0.000000
2.0 4.220980 0.163424 6.546015 0.000000 0.000000 2.262038 0.000000
3.0 4.898215 0.129962 4.388547 0.000000 0.000000 4.608551 0.000000
4.0 6.265324 0.017554 0.000000 0.000000 0.000000 5.574626 0.000000
5.0 6.327811 0.000000 0.000000 0.000000 0.000000 5.234427 0.000000
6.0 6.086203 0.000000 0.000000 0.000000 0.000000 4.694722 0.000000
7.0 5.647119 0.000000 0.000000 0.000000 0.000000 3.025147 0.000000
8.0 3.282605 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
9.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
10.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
11.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
12.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
13.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
14.0 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000
a8 a9
-15.0 -6.712116 9.088898
-14.0 -6.712115 9.088898
-13.0 -6.712114 9.088897
-12.0 -6.712111 9.088896
-11.0 -6.712102 9.088893
-10.0 -6.712079 9.088886
-9.0 -6.712017 9.088865
-8.0 -6.711847 9.088809
-7.0 -6.711386 9.088657
-6.0 -6.710131 9.088244
-5.0 -6.706721 9.087122
-4.0 -6.697451 9.084071
-3.0 -6.672909 9.075416
-2.0 -6.596948 9.047518
-1.0 -6.382539 8.958936
0.0 -5.165522 8.637044
1.0 -0.365352 8.445773
2.0 0.000000 5.639226
3.0 0.000000 0.000000
4.0 0.000000 0.000000
5.0 0.000000 0.000000
6.0 0.000000 0.000000
7.0 0.000000 0.000000
8.0 0.000000 0.000000
9.0 0.000000 0.000000
10.0 0.000000 0.000000
11.0 0.000000 0.000000
12.0 0.000000 0.000000
13.0 0.000000 0.000000
14.0 0.000000 0.000000
# Create a figure/axes pair at 200 dpi and plot every column of
# comnine_ar_frame against the frame's index (log of the penalty) on it.
sdf,pis=plt.subplots(dpi=200)
comnine_ar_frame.plot(ax=pis)
# Display the figure.
plt.show()
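# 兄弟连学python (Xiongdilian: learn Python)
# Python学习交流、资源共享群:563626388 QQ (Python study, discussion and resource-sharing QQ group: 563626388)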