11. Weighted Linear Regression Case Study: Predicting the Age of Abalone


Dataset: https://download.csdn.net/download/weixin_44827418/12553408

1. Import the dataset

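Dataset description (the original figure is not available; the nine columns, named in Chinese in the code below, are sex, length, diameter, height, whole weight, meat weight, viscera weight, shell weight and age, where age is the label to predict):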

import pandas as pd
import numpy as np

abalone = pd.read_table("./datas/abalone.txt",header=None)
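# Column names: sex, length, diameter, height, whole weight, meat weight, viscera weight, shell weight, age
abalone.columns=['性别','长度','直径','高度','整体重量','肉重量','内脏重量','壳重','年龄']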
abalone.head()
   性别     长度     直径     高度   整体重量    肉重量   内脏重量     壳重  年龄
0     1  0.455  0.365  0.095  0.5140  0.2245  0.1010  0.150   15
1     1  0.350  0.265  0.090  0.2255  0.0995  0.0485  0.070    7
2    -1  0.530  0.420  0.135  0.6770  0.2565  0.1415  0.210    9
3     1  0.440  0.365  0.125  0.5160  0.2155  0.1140  0.155   10
4     0  0.330  0.255  0.080  0.2050  0.0895  0.0395  0.055    7
abalone.shape
(4177, 9)
abalone.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 4177 entries, 0 to 4176
Data columns (total 9 columns):
 #   Column  Non-Null Count  Dtype  
---  ------  --------------  -----  
 0   性别      4177 non-null   int64  
 1   长度      4177 non-null   float64
 2   直径      4177 non-null   float64
 3   高度      4177 non-null   float64
 4   整体重量    4177 non-null   float64
 5   肉重量     4177 non-null   float64
 6   内脏重量    4177 non-null   float64
 7   壳重      4177 non-null   float64
 8   年龄      4177 non-null   int64  
dtypes: float64(7), int64(2)
memory usage: 293.8 KB
abalone.describe()
              性别          长度          直径          高度        整体重量         肉重量        内脏重量          壳重          年龄
count  4177.000000  4177.000000  4177.000000  4177.000000  4177.000000  4177.000000  4177.000000  4177.000000  4177.000000
mean      0.052909     0.523992     0.407881     0.139516     0.828742     0.359367     0.180594     0.238831     9.933684
std       0.822240     0.120093     0.099240     0.041827     0.490389     0.221963     0.109614     0.139203     3.224169
min      -1.000000     0.075000     0.055000     0.000000     0.002000     0.001000     0.000500     0.001500     1.000000
25%      -1.000000     0.450000     0.350000     0.115000     0.441500     0.186000     0.093500     0.130000     8.000000
50%       0.000000     0.545000     0.425000     0.140000     0.799500     0.336000     0.171000     0.234000     9.000000
75%       1.000000     0.615000     0.480000     0.165000     1.153000     0.502000     0.253000     0.329000    11.000000
max       1.000000     0.815000     0.650000     1.130000     2.825500     1.488000     0.760000     1.005000    29.000000

2. Inspect the data distribution

import numpy as np
import pandas as pd
import random
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif']=['simhei'] # font that can display Chinese characters
plt.rcParams['axes.unicode_minus']=False # render the minus sign correctly
%matplotlib inline
mpl.cm.rainbow(np.linspace(0,1,10))
array([[5.00000000e-01, 0.00000000e+00, 1.00000000e+00, 1.00000000e+00],
       [2.80392157e-01, 3.38158275e-01, 9.85162233e-01, 1.00000000e+00],
       [6.07843137e-02, 6.36474236e-01, 9.41089253e-01, 1.00000000e+00],
       [1.66666667e-01, 8.66025404e-01, 8.66025404e-01, 1.00000000e+00],
       [3.86274510e-01, 9.84086337e-01, 7.67362681e-01, 1.00000000e+00],
       [6.13725490e-01, 9.84086337e-01, 6.41213315e-01, 1.00000000e+00],
       [8.33333333e-01, 8.66025404e-01, 5.00000000e-01, 1.00000000e+00],
       [1.00000000e+00, 6.36474236e-01, 3.38158275e-01, 1.00000000e+00],
       [1.00000000e+00, 3.38158275e-01, 1.71625679e-01, 1.00000000e+00],
       [1.00000000e+00, 1.22464680e-16, 6.12323400e-17, 1.00000000e+00]])
mpl.cm.rainbow(np.linspace(0,1,10))[0]
array([0.5, 0. , 1. , 1. ])
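def dataPlot(dataSet):
    m,n = dataSet.shape
    fig = plt.figure(figsize=(8,20),dpi=100)
    colormap = mpl.cm.rainbow(np.linspace(0,1,n)) # one RGBA color per column
    for i in range(n):
        ax = fig.add_subplot(n,1,i+1) # one subplot per column
        # color= takes a single RGBA color; passing one RGBA row via c= makes matplotlib
        # warn that it may be mistaken for values to color-map
        ax.scatter(range(m),dataSet.iloc[:,i].values,s=2,color=colormap[i])
        ax.set_title(dataSet.columns[i])
    fig.tight_layout(pad=1.2) # adjust the spacing between subplots
# Run the function to view the data distribution: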
dataPlot(abalone)

[Figure: scatter plot of each column of abalone against the row index (output_10_1.png, image not available)]

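The scatter plots show that:

1) Apart from "sex" (性别), every column shows a clear ordered pattern along the row index, i.e. the rows are not stored in random order.

2) The "height" (高度) feature contains two outliers.

Based on these observations, we take two measures:

1) Shuffle the original dataset before splitting it into training and test sets, so that samples are selected at random.

2) Remove the outliers from the "height" feature.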
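The two measures above can be illustrated on a tiny made-up DataFrame. This is a minimal sketch rather than the author's code: the toy values are invented, and shuffling with `DataFrame.sample(frac=1, random_state=...)` is an alternative to the `random.shuffle` approach used in this article; the 0.4 cut-off for 高度 is the one used in the code that follows.

```python
import pandas as pd

# Hypothetical toy data: the height 1.130 plays the role of an outlier
toy = pd.DataFrame({
    "高度": [0.095, 0.090, 1.130, 0.125, 0.080],
    "年龄": [15, 7, 9, 10, 7],
})

# Measure 2: keep only rows whose height is below the 0.4 threshold
clean = toy.loc[toy["高度"] < 0.4, :]

# Measure 1: shuffle the rows reproducibly before any train/test split,
# so the ordered pattern along the row index does not end up in one split
shuffled = clean.sample(frac=1, random_state=123).reset_index(drop=True)

print(len(toy), len(clean))  # 5 4
```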

abalone['高度']<0.4
0       True
1       True
2       True
3       True
4       True
        ... 
4172    True
4173    True
4174    True
4175    True
4176    True
Name: 高度, Length: 4177, dtype: bool
aba = abalone.loc[abalone['高度']<0.4,:]
# Plot the distribution again after removing the outliers
dataPlot(aba)

[Figure: scatter plot of each column of aba after removing the height outliers (output_18_1.png, image not available)]

3. Split the training and test sets

"""
函数功能:随机切分训练集和测试集
参数说明:
    dataSet:原始数据集
    rate:训练集比例
返回:
    train,test:切分好的训练集和测试集
"""
def randSplit(dataSet,rate):
    l = list(dataSet.index) # 将原始数据集的索引提取出来,存到列表中
    random.seed(123) # 设置随机数种子
    random.shuffle(l) # 随机打乱数据集中的索引
    dataSet.index = l # 把打乱后的索引重新赋值给数据集中的索引,
    # 索引打乱了就相当于打乱了原始数据集中的数据
    m = dataSet.shape[0] # 原始数据集样本总数
    n = int(m*rate) # 训练集样本数量
    train = dataSet.loc[range(n),:] # 从打乱了的原始数据集中提取出训练集数据
    test = dataSet.loc[range(n,m),:] # 从打乱了的原始数据集中提取出测试集数据
    train.index = range(train.shape[0]) # 重置train训练数据集中的索引
    test.index = range(test.shape[0]) # 重置test测试数据集中的索引
    dataSet.index = range(dataSet.shape[0]) # 重置原始数据集中的索引
    return train,test
train,test = randSplit(aba,0.8)
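A quick sanity check for any such split is that the two parts are disjoint and together contain every row. The sketch below uses a hypothetical `rand_split_sample` helper and made-up data to do the same shuffle-then-cut split via `DataFrame.sample`; its permutation differs from `random.shuffle` with seed 123, so on `aba` it would select different rows than the output shown here.

```python
import pandas as pd

def rand_split_sample(data, rate, seed=123):
    """Hypothetical alternative to randSplit: shuffle with DataFrame.sample, then cut."""
    shuffled = data.sample(frac=1, random_state=seed)  # random permutation of the rows
    n = int(len(shuffled) * rate)                      # number of training rows
    train = shuffled.iloc[:n].reset_index(drop=True)
    test = shuffled.iloc[n:].reset_index(drop=True)
    return train, test

# Usage on made-up data: 10 rows with a unique id column
df = pd.DataFrame({"id": range(10), "y": [v * 2 for v in range(10)]})
tr, te = rand_split_sample(df, 0.8)
print(tr.shape, te.shape)  # (8, 2) (2, 2)
```

Because `sample` returns a new DataFrame, the input keeps its original index, unlike the original version of `randSplit`, which reassigned the caller's index.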
# Explore the training set
train.head()
   性别     长度     直径     高度   整体重量    肉重量   内脏重量      壳重  年龄
0    -1  0.590  0.470  0.170  0.9000  0.3550  0.1905  0.2500   11
1     1  0.560  0.450  0.145  0.9355  0.4250  0.1645  0.2725   11
2    -1  0.635  0.535  0.190  1.2420  0.5760  0.2475  0.3900   14
3     1  0.505  0.390  0.115  0.5585  0.2575  0.1190  0.1535    8
4     1  0.510  0.410  0.145  0.7960  0.3865  0.1815  0.1955    8
train.shape
(3340, 9)
train.describe() # summary statistics
              性别          长度          直径          高度        整体重量         肉重量        内脏重量          壳重          年龄
count  3340.000000  3340.000000  3340.000000  3340.000000  3340.000000  3340.000000  3340.000000  3340.000000  3340.000000
mean      0.060479     0.522754     0.406886     0.138790     0.824906     0.358151     0.179732     0.237158     9.911976
std       0.819021     0.120300     0.099372     0.038441     0.488535     0.222422     0.109036     0.137920     3.223534
min      -1.000000     0.075000     0.055000     0.000000     0.002000     0.001000     0.000500     0.001500     1.000000
25%      -1.000000     0.450000     0.350000     0.115000     0.439000     0.184375     0.092000     0.130000     8.000000
50%       0.000000     0.540000     0.420000     0.140000     0.796750     0.335500     0.171000     0.232000     9.000000
75%       1.000000     0.615000     0.480000     0.165000     1.147250     0.498500     0.250500     0.325000    11.000000
max       1.000000     0.780000     0.630000     0.250000     2.825500     1.488000     0.760000     1.005000    27.000000
dataPlot(train) # plot the training-set distribution

[Figure: scatter plot of each column of the training set (output_26_1.png, image not available)]

# Explore the test set
test.head() 
   性别     长度     直径     高度   整体重量    肉重量   内脏重量      壳重  年龄
0     1  0.630  0.470  0.150  1.1355  0.5390  0.2325  0.3115   12
1    -1  0.585  0.445  0.140  0.9130  0.4305  0.2205  0.2530   10
2    -1  0.390  0.290  0.125  0.3055  0.1210  0.0820  0.0900    7
3     1  0.525  0.410  0.130  0.9900  0.3865  0.2430  0.2950   15
4     1  0.625  0.475  0.160  1.0845  0.5005  0.2355  0.3105   10
test.shape 
(835, 9)
test.describe() 
              性别          长度          直径          高度        整体重量         肉重量        内脏重量          壳重          年龄
count   835.000000   835.000000   835.000000   835.000000   835.000000   835.000000   835.000000   835.000000   835.000000
mean      0.022754     0.528808     0.411737     0.140784     0.842714     0.363370     0.183749     0.245320    10.022754
std       0.834341     0.119166     0.098627     0.038664     0.495990     0.218938     0.111510     0.143925     3.230284
min      -1.000000     0.130000     0.100000     0.015000     0.013000     0.004500     0.003000     0.004000     3.000000
25%      -1.000000     0.450000     0.350000     0.115000     0.458000     0.192000     0.096500     0.132750     8.000000
50%       0.000000     0.550000     0.430000     0.140000     0.810000     0.339000     0.170500     0.235000    10.000000
75%       1.000000     0.620000     0.485000     0.170000     1.177250     0.510750     0.259250     0.337000    11.000000
max       1.000000     0.815000     0.650000     0.250000     2.555000     1.145500     0.590000     0.815000    29.000000
dataPlot(test)

[Figure: scatter plot of each column of the test set (output_30_1.png, image not available)]

4. Build helper functions

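'''
Function: take a DataFrame whose last column is the label and return the feature matrix and the label matrix
'''
def get_Mat(dataSet):
    xMat = np.asmatrix(dataSet.iloc[:,:-1].values) # m x n feature matrix (np.mat is an alias of np.asmatrix)
    yMat = np.asmatrix(dataSet.iloc[:,-1].values).T # m x 1 label column
    return xMat,yMat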
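NumPy's documentation no longer recommends `np.matrix`, even for linear algebra, and suggests plain arrays with the `@` operator instead. As a hedged sketch, the same helper returning ndarrays could look like this (`get_arrays` is a hypothetical name); the label is reshaped to a column of shape (m, 1) so that it lines up with `yMat`.

```python
import numpy as np
import pandas as pd

def get_arrays(data):
    """Hypothetical ndarray version of get_Mat: features = all but the last column, label = last column."""
    X = data.iloc[:, :-1].to_numpy(dtype=float)                # shape (m, n)
    y = data.iloc[:, -1].to_numpy(dtype=float).reshape(-1, 1)  # shape (m, 1)
    return X, y

# Usage on a made-up two-row frame
df = pd.DataFrame([[1.0, 0.5, 3.0], [1.0, 0.2, 2.5]])
X, y = get_arrays(df)
print(X.shape, y.shape)  # (2, 2) (2, 1)
```

With arrays, products are written `X.T @ X` instead of `xMat.T * xMat`, and `*` means element-wise multiplication.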

'''
Function: visualize the dataset by plotting the second feature column (index 1) against the label
'''
def plotShow(dataSet):
    xMat,yMat = get_Mat(dataSet)
    plt.scatter(xMat.A[:,1],yMat.A,c='b',s=5)
    plt.show()

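'''
Function: compute the regression coefficients of ordinary least squares
Parameters:
    dataSet: original dataset (last column is the label)
Returns:
    ws: regression coefficients
'''
def standRegres(dataSet):
    xMat,yMat = get_Mat(dataSet)
    xTx = xMat.T * xMat
    if np.linalg.det(xTx) == 0:
        print('The matrix is singular and cannot be inverted!')
        return
    ws = xTx.I*(xMat.T*yMat) # xTx.I is the inverse of xTx (normal equation)
    return ws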
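`standRegres` solves the normal equation w = (XᵀX)⁻¹Xᵀy by explicitly inverting XᵀX. Below is a small check on made-up data (exactly y = 3 + 2x, with a column of ones as the intercept) that this agrees with `np.linalg.lstsq`, which solves the same least-squares problem without forming the inverse and is generally the more numerically robust option.

```python
import numpy as np

# Made-up data: intercept column of ones plus one feature, y = 3 + 2x exactly
x = np.array([0.0, 1.0, 2.0, 3.0, 4.0])
X = np.column_stack([np.ones_like(x), x])
y = 3.0 + 2.0 * x

# Normal equation with an explicit inverse, as in standRegres
w_inv = np.linalg.inv(X.T @ X) @ (X.T @ y)

# Least-squares solver that does not form the inverse
w_lstsq, *_ = np.linalg.lstsq(X, y, rcond=None)

print(w_inv, w_lstsq)
```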
"""
函数功能:计算误差平方和SSE
参数说明:
    dataSet:真实值
    regres:求回归系数的函数
返回:
    SSE:误差平方和
"""
def sseCal(dataSet, regres):
    xMat,yMat = get_Mat(dataSet)
    ws = regres(dataSet)
    yHat = xMat*ws
    sse = ((yMat.A.flatten() - yHat.A.flatten())**2).sum()#  
    return sse

Using the ex0 dataset as an example, check the output of these functions:

ex0 = pd.read_table("./datas/ex0.txt",header=None)
ex0.head()
     0         1         2
0  1.0  0.067732  3.176513
1  1.0  0.427810  3.816464
2  1.0  0.995731  4.550095
3  1.0  0.738336  4.256571
4  1.0  0.981083  4.560815
# SSE of ordinary linear regression
sseCal(ex0, standRegres)
1.3552490816814902

Build a function to compute the coefficient of determination R²:

"""
函数功能:计算相关系数R2
"""
def rSquare(dataSet,regres):
    xMat,yMat=get_Mat(dataSet)
    sse = sseCal(dataSet,regres)
    sst = ((yMat.A-yMat.mean())**2).sum()#  
    r2 = 1 - sse / sst
    return r2
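For reference, the two metrics are SSE = Σ(yᵢ - ŷᵢ)² and R² = 1 - SSE/SST with SST = Σ(yᵢ - ȳ)². A minimal sketch with a hypothetical `sse_r2` helper on plain arrays and invented numbers: perfect predictions give R² = 1, and always predicting the mean gives R² = 0.

```python
import numpy as np

def sse_r2(y_true, y_pred):
    """Hypothetical helper: return (SSE, R^2) for 1-D true and predicted values."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    sse = ((y_true - y_pred) ** 2).sum()          # sum of squared errors
    sst = ((y_true - y_true.mean()) ** 2).sum()   # total sum of squares around the mean
    return sse, 1 - sse / sst

y = [1.0, 2.0, 3.0, 4.0]
print(sse_r2(y, y))                     # perfect fit
print(sse_r2(y, [2.5, 2.5, 2.5, 2.5]))  # always predicts the mean
print(sse_r2(y, [1.0, 2.0, 3.0, 5.0]))  # one error of size 1: SSE = 1, SST = 5
```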

Again using the ex0 dataset as an example:

# R² of ordinary linear regression
rSquare(ex0, standRegres)
0.9731300889856916
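The `LWLR` function that follows implements locally weighted least squares with a Gaussian kernel. For a query point x, training sample x_j gets the diagonal weight w_jj of W, and the coefficients come from the weighted normal equation; k is the bandwidth, and a smaller k concentrates the weight on nearby samples:

```latex
w_{jj} = \exp\!\left(-\frac{\lVert x - x_j \rVert^2}{2k^2}\right), \qquad
\hat{\theta}(x) = \left(X^\top W X\right)^{-1} X^\top W y, \qquad
\hat{y}(x) = x^\top \hat{\theta}(x)
```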
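'''
Function: compute the predictions of locally weighted linear regression (LWLR)
Parameters:
    testMat: feature matrix of the points to predict (e.g. the test set)
    xMat: feature matrix of the training set
    yMat: label matrix of the training set
    k: bandwidth of the Gaussian kernel; a smaller k gives more weight to nearby samples
Returns:
    ws: regression coefficients of the last predicted point
    yHat: predicted values
'''
def LWLR(testMat,xMat,yMat,k=1.0):
    n = testMat.shape[0] # number of rows to predict
    m = xMat.shape[0] # number of training rows
    weights = np.asmatrix(np.eye(m)) # initialize the weight matrix as an identity matrix
    yHat = np.zeros(n) # initialize the predictions with zeros
    for i in range(n):
        for j in range(m):
            diffMat = testMat[i] - xMat[j]
            # Gaussian kernel weight of training sample j for query point i
            weights[j,j] = np.exp((diffMat*diffMat.T)[0,0] / (-2*k**2))
        xTx = xMat.T*(weights*xMat)
        if np.linalg.det(xTx) == 0:
            print('The matrix is singular and cannot be inverted!')
            return
        ws = xTx.I*(xMat.T*(weights*yMat)) # weighted normal equation
        yHat[i] = (testMat[i] * ws)[0,0] # 1x1 matrix -> scalar
    return ws,yHat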
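The nested loop multiplies a dense m×m weight matrix for every query point, which makes LWLR slow on the full training set. Below is a hedged sketch of the same estimator with plain NumPy arrays (`lwlr_fast` is a hypothetical name). The weights for all training rows are computed at once and applied by scaling rows with √w instead of building a diagonal matrix, and `np.linalg.lstsq` replaces the explicit inverse. Unlike `LWLR`, it does not stop when XᵀWX is singular; `lstsq` returns a minimum-norm solution instead. It returns only the predictions.

```python
import numpy as np

def lwlr_fast(test_X, X, y, k=1.0):
    """Hypothetical vectorized LWLR: predict every row of test_X with a Gaussian-weighted fit."""
    test_X = np.asarray(test_X, dtype=float)   # also turns np.matrix into a plain array
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float).ravel()
    y_hat = np.empty(test_X.shape[0])
    for i, x in enumerate(test_X):
        d2 = ((X - x) ** 2).sum(axis=1)        # squared distance to every training row
        w = np.exp(-d2 / (2 * k ** 2))         # Gaussian kernel weights (diagonal of W)
        sw = np.sqrt(w)
        # Weighted least squares: minimize sum_j w_j * (y_j - X_j @ theta)^2
        theta, *_ = np.linalg.lstsq(X * sw[:, None], y * sw, rcond=None)
        y_hat[i] = x @ theta
    return y_hat

# Usage on made-up data lying exactly on y = 1 + 2x (with an intercept column):
# every local weighted fit recovers the line, so predictions equal the targets
x = np.linspace(0.0, 1.0, 20)
X = np.column_stack([np.ones_like(x), x])
y = 1.0 + 2.0 * x
pred = lwlr_fast(X[:5], X, y, k=0.5)
print(np.allclose(pred, y[:5]))  # True
```

On the abalone data it could be called as `lwlr_fast(testX, trainX, trainY, k=2)`, since `np.asarray` converts the `np.matrix` objects returned by `get_Mat`.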

5. Build the locally weighted linear regression model

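Because LWLR on the full data is very slow, the first 100 samples of the training set are used as training data here, and the first 100 samples of the test set as test data.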

"""
函数功能:绘制不同k取值下,训练集和测试集的SSE曲线
"""
def ssePlot(train,test):
    X0,Y0 = get_Mat(train)
    X1,Y1 =get_Mat(test)
    train_sse = []
    test_sse = []
    for k in np.arange(0.2,10,0.5):
        ws1,yHat1 = LWLR(X0[:99],X0[:99],Y0[:99],k) 
        sse1 = ((Y0[:99].A.T - yHat1)**2).sum() 
        train_sse.append(sse1)
        
        ws2,yHat2 = LWLR(X1[:99],X0[:99],Y0[:99],k) 
        sse2 = ((Y1[:99].A.T - yHat2)**2).sum() 
        test_sse.append(sse2)
        
    plt.figure(figsize=(20,8),dpi=100)
    plt.plot(np.arange(0.2,10,0.5),train_sse,color='b')#     
    plt.plot(np.arange(0.2,10,0.5),test_sse,color='r') 
    plt.xlabel('不同k取值')
    plt.ylabel('SSE')
    plt.legend(['train_sse','test_sse'])

Result:

ssePlot(train,test)

[Figure: training SSE (blue) and test SSE (red) for different values of k (output_47_0.png, image not available)]

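Read this plot from right to left. For large k the model is fairly stable. As k decreases, the training SSE gradually drops, and at around k = 2 the training SSE and the test SSE are roughly equal. As k decreases further, the training SSE keeps shrinking, i.e. the model fits the training set better and better, while its performance on the test set gets worse and worse: the model has started to overfit. This agrees with the earlier result plots for different k values, where the three plots for k = 1.0, 0.01 and 0.003 also show the model gradually overfitting as k decreases. So a k of about 2 looks like the best choice here.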
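Reading the best k off the plot can also be done programmatically by taking the k with the smallest test SSE. A minimal sketch; the SSE curve below is a made-up placeholder, not the result shown in the figure, and in practice one would use the `test_sse` list built inside `ssePlot`.

```python
import numpy as np

ks = np.arange(0.2, 10, 0.5)           # same grid of k values as ssePlot
# Hypothetical test-SSE curve with its minimum near k = 2
test_sse = (ks - 2.2) ** 2 + 50.0

best_k = ks[np.argmin(test_sse)]       # k with the lowest test error
print(round(float(best_k), 1))  # 2.2
```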

Next, plug k = 2 into the LWLR model and look at the predictions:

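train,test = randSplit(aba,0.8) # randomly split the original dataset into a training set and a test set
trainX,trainY = get_Mat(train) # split the training set into a feature matrix and a label matrix
testX,testY = get_Mat(test) # split the test set into a feature matrix and a label matrix
ws0,yHat0 = LWLR(testX,trainX,trainY,k=2)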

Plot the true values against the predicted values:

y=testY.A.flatten()
plt.scatter(y,yHat0,c='b',s=5) # x-axis: true value, y-axis: predicted value
plt.show()

[Figure: true values (x-axis) versus predicted values (y-axis) on the test set (output_52_0.png, image not available)]

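In the plot above, the x-axis is the true value and the y-axis is the predicted value. The points form a funnel (trumpet) shape: as the true value increases, the predictions also increase but scatter more and more widely, which shows that the prediction error grows with the true value.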
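One way to check the funnel shape numerically is to group the samples by true value and compare the spread of the residuals (prediction minus truth) in each group. A hedged sketch with a hypothetical `residual_spread` helper and invented numbers; on the real data one would pass `y` and `yHat0`.

```python
import numpy as np

def residual_spread(y_true, y_pred, n_bins=2):
    """Hypothetical helper: std of residuals within equal-count bins of sorted true values."""
    y_true = np.asarray(y_true, dtype=float)
    resid = np.asarray(y_pred, dtype=float) - y_true
    order = np.argsort(y_true)                   # sort samples by true value
    bins = np.array_split(resid[order], n_bins)  # equal-count groups, low to high
    return [float(b.std()) for b in bins]

# Made-up values: errors of +/-0.5 for small targets, +/-3 for large ones
y_true = [5, 6, 7, 8, 15, 16, 17, 18]
y_pred = [5.5, 5.5, 7.5, 7.5, 12, 19, 14, 21]
print(residual_spread(y_true, y_pred))  # [0.5, 3.0]
```

A residual standard deviation that rises from bin to bin is the numerical counterpart of the widening funnel.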

Wrap the SSE and R² computation in a function for later reuse:

"""
函数功能:计算加权线性回归的SSE和R方
"""
def LWLR_pre(dataSet):
    train,test = randSplit(dataSet,0.8)#      
    trainX,trainY = get_Mat(train)
    testX,testY = get_Mat(test)
    ws,yHat = LWLR(testX,trainX,trainY,k=2)#     
    sse = ((testY.A.T - yHat)**2).sum()#     
    sst = ((testY.A-testY.mean())**2).sum() #     
    r2 = 1 - sse / sst
    return sse,r2

Check the model's results:

LWLR_pre(aba)
(4152.777097646255, 0.5228101340130846)

The SSE exceeds 4000 and R² is only about 0.52, so the model does not perform very well.
