NBA球员得分预测-基于线性回归、KNN回归、决策树回归、随机森林回归

前言

在NBA中,预测每个球员的得分在篮球分析领域至关重要。它是一个关键的表现指标,允许教练、分析师和球迷评估球员的得分能力和对球队的整体进攻贡献。了解球员的得分潜力有助于比赛中的战略决策、球员选择和人才发掘。在本篇报告中,我们深入研究了篮球数据分析领域并使用机器学习技术来预测每个球员的得分水平。

预测所采用的回归模型:

  • 线性回归
  • KNN回归器
  • 决策树回归器
  • 随机森林回归器

通过使用这些回归模型,旨在了解它们在预测球员得分方面的表现,并比较各自的预测能力。通过对比分析,可以从实际意义上考量不同模型各自的优劣,并在这个特定的数据集中确定最适合预测球员得分的模型。

一 数据集概述

2023_nba_player_stats.csv
在该数据集中,包含2023年所有NBA球员的各项指标数据。其中各列名简称的实际解释意义如下:

PNamePosTeamAgeGPW
球员姓名球员位置所属球队年龄出场次数胜场
LMinPTSFGMFGAFG%
负场出场时间总得分投篮命中数投篮总次数投篮命中率
3PM3PA3P%FTMFTAFT%
三分命中数三分出手数三分命中率罚球命中数罚球总次数罚球命中率
OREBDREBREBASTTOVSTL
进攻篮板数防守篮板数总篮板数总助攻数总失误数总抢断数
BLKPFFPDD2TD3+/-
总盖帽数个人犯规数虚拟得分两双数三双数正负值总和

其中,球员虚拟得分(FP)指的是在NBA2K2023中进行模拟球队对局所产生的常规赛各球员得分总数。其余各项指标均为篮球基本术语,在此不过多解释。

二 导入库

在进行数据分析与处理的过程中,需要在pycharm编辑器中导入数据操作与可视化所需的库。

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots

from sklearn.model_selection import train_test_split, 
                                    GridSearchCV, 
                                    cross_val_score
from sklearn.metrics import classification_report, 
                            confusion_matrix, 
                            f1_score, r2_score

from sklearn.linear_model import LinearRegression
from sklearn.neighbors import KNeighborsRegressor
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import RandomForestRegressor

import warnings

warnings.filterwarnings("ignore")

众所周知,Python拥有强大的库资源。在本篇报告中,基于数据主要导入了pandas库用于数据分析与处理、numpy库用于数值计算、matplotlib库用于数据可视化呈现以及plotly库用于实现交互式数据可视化;基于模型则导入了linear_model、neighbors、tree、ensemble各自的回归器,分别用于实现线性回归、K最近邻回归、决策树回归以及随机森林回归。同时,导入warnings库用于在控制台忽视warnings信息。

三 读取数据集

3.1 读取数据

利用pandas库读取csv文件,读取后的数据类型为DataFrame类型。

# 读取nba球员数据为csv文件
df = pd.read_csv('E:\\数据文件\\2023_nba_player_stats.csv')

3.2 数据集探索

  1. 识别数据集的行数与列数
  2. 修改列名
  3. 加载数据集基本信息
  4. 描述性统计

# 数据集的行数和列数
row, col = df.shape
print("This Dataset have", row, "rows and", col, "columns.")
print("Number of duplicate data : ", df.duplicated().sum())
This Dataset have 539 rows and 30 columns.
Number of duplicate data :  0

数据集包含539行,30列,其中完全重复数据为0条。


df.rename(columns={
    'PName': 'Player_Name',
    'POS': 'Position',
    'Team': 'Team_Abbreviation',
    'Age': 'Age',
    'GP': 'Games_Played',
    'W': 'Wins',
    'L': 'Losses',
    'Min': 'Minutes_Played',
    'PTS': 'Total_Points',
    'FGM': 'Field_Goals_Made',
    'FGA': 'Field_Goals_Attempted',
    'FG%': 'Field_Goal_Percentage',
    '3PM': 'Three_Point_FG_Made',
    '3PA': 'Three_Point_FG_Attempted',
    '3P%': 'Three_Point_FG_Percentage',
    'FTM': 'Free_Throws_Made',
    'FTA': 'Free_Throws_Attempted',
    'FT%': 'Free_Throw_Percentage',
    'OREB': 'Offensive_Rebounds',
    'DREB': 'Defensive_Rebounds',
    'REB': 'Total_Rebounds',
    'AST': 'Assists',
    'TOV': 'Turnovers',
    'STL': 'Steals',
    'BLK': 'Blocks',
    'PF': 'Personal_Fouls',
    'FP': 'NBA_Fantasy_Points',
    'DD2': 'Double_Doubles',
    'TD3': 'Triple_Doubles',
    '+/-': 'Plus_Minus'
}, inplace=True)

将原始数据的列名缩写修改为全称。


df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 539 entries, 0 to 538
Data columns (total 30 columns):
 #   Column                     Non-Null Count  Dtype  
---  ------                     --------------  -----  
 0   Player_Name                539 non-null    object 
 1   Position                   534 non-null    object 
 2   Team_Abbreviation          539 non-null    object 
 3   Age                        539 non-null    int64  
 4   Games_Played               539 non-null    int64  
 5   Wins                       539 non-null    int64  
 6   Losses                     539 non-null    int64  
 7   Minutes_Played             539 non-null    float64
 8   Total_Points               539 non-null    int64  
 9   Field_Goals_Made           539 non-null    int64  
 10  Field_Goals_Attempted      539 non-null    int64  
 11  Field_Goal_Percentage      539 non-null    float64
 12  Three_Point_FG_Made        539 non-null    int64  
 13  Three_Point_FG_Attempted   539 non-null    int64  
 14  Three_Point_FG_Percentage  539 non-null    float64
 15  Free_Throws_Made           539 non-null    int64  
 16  Free_Throws_Attempted      539 non-null    int64  
 17  Free_Throw_Percentage      539 non-null    float64
 18  Offensive_Rebounds         539 non-null    int64  
 19  Defensive_Rebounds         539 non-null    int64  
 20  Total_Rebounds             539 non-null    int64  
 21  Assists                    539 non-null    int64  
 22  Turnovers                  539 non-null    int64  
 23  Steals                     539 non-null    int64  
 24  Blocks                     539 non-null    int64  
 25  Personal_Fouls             539 non-null    int64  
 26  NBA_Fantasy_Points         539 non-null    int64  
 27  Double_Doubles             539 non-null    int64  
 28  Triple_Doubles             539 non-null    int64  
 29  Plus_Minus                 539 non-null    int64  
dtypes: float64(4), int64(23), object(3)
memory usage: 126.5+ KB

从加载的数据集基本信息表发现,除Position列存在NaN型数据外,其余列数据项均完整。


print(df.describe(include=np.number))
print(df.describe(include='object'))
              Age  Games_Played  ...  Triple_Doubles  Plus_Minus
count  539.000000    539.000000  ...      539.000000  539.000000
mean    25.970315     48.040816  ...        0.220779    0.000000
std      4.315513     24.650686  ...        1.564432  148.223909
min     19.000000      1.000000  ...        0.000000 -642.000000
25%     23.000000     30.500000  ...        0.000000  -70.000000
50%     25.000000     54.000000  ...        0.000000   -7.000000
75%     29.000000     68.000000  ...        0.000000   57.000000
max     42.000000     83.000000  ...       29.000000  640.000000

[8 rows x 27 columns]
         Player_Name Position Team_Abbreviation
count            539      534               539
unique           539        7                30
top     Jayson Tatum       SG               DAL
freq               1       96                21

从描述性统计来看,出现频率最高的球员位置为SG(得分后卫),出现频率为96次,不同的球员位置包括7种,分别为PG、SG、SF、PF、C以及不明确的G与F,其中PG、SG属于G的划分,SF、PF则属于F的划分。NBA球员的平均年龄为26岁,最小的仅19岁。


3.3 数据可视化

在数据集探索过程中,发现Position列中存在NaN型数据,在描述性统计中发现SG为Position列中出现频率最高的一项,因此,考虑将缺失项修正为SG。

df['Position'].fillna('SG', inplace=True)

在此之后,可以考虑将数据按照球员位置分组进行可视化呈现。其中包括根据球员位置分组得到的平均总得分、球员年龄的频数分布直方图、按位置分组下球员年龄与总得分、投篮命中率、总助攻的二维关系散点图等。


position_stats = df.groupby(['Position']).agg({'Total_Points': 'mean'}).reset_index()

fig = go.Figure()

fig.add_trace(go.Bar(
    x=position_stats['Position'],
    y=position_stats['Total_Points'],
    marker=dict(color=['red', 'brown', 'white','purple', 
                       'cyan', 'blue','orange']),
))

fig.update_layout(
    title='Points per Position',
    xaxis_title='Position',
    yaxis_title='Average Total Points',
    template='plotly_dark'
)
fig.show()

图1 各位置下的球员平均总得分


fig = go.Figure()

fig.add_trace(go.Histogram(x=df['Age'], marker_color='white'))

fig.update_layout(title='Distribution of Player Ages',
                  xaxis_title='Age',
                  yaxis_title='Count',
                  template='plotly_dark',
                  bargap=0.1,  # 设置柱子之间的间隔
                  bargroupgap=0.1)
                            
fig_age_histogram.show()

图2 球员年龄频数分布


fig = px.scatter(df, x='Age', y='Total_Points', color='Position',
                title='Player Age vs Total Points',
                labels={'Age': 'Age', 'Total_Points': 'Total Points'},
                template='plotly_dark')
fig.show()

图3 按年龄、位置划分下的球员总得分


fig = px.scatter(df, x='Age', y='Field_Goal_Percentage',
                 color='Position',
                 title='Player Age vs Field Goal Percentage',
                 labels={'Age': 'Age', 
                         'Field_Goal_Percentage': 
                         'Field Goal Percentage'},
                 template='plotly_dark')
fig.show()

图4 按年龄、位置划分下的投篮命中率


fig_assists = px.scatter(df, x='Age', y='Assists', color='Position',
                         title='Player Age vs Assists',
                         labels={'Age': 'Age', 'Assists': 'Assists'},
                         template='plotly_dark')
fig_assists.show()

图5 按年龄、位置划分下的总助攻


avg = df.groupby('Position')['NBA_Fantasy_Points'].mean().reset_index()

fig = go.Figure()

fig.add_trace(go.Bar(x=avg['Position'],
                    y=avg['NBA_Fantasy_Points'],
                    marker_color='white'))

fig.update_layout(title='Average Fantasy Points by Position',
                xaxis_title='Position',
                yaxis_title='Average Fantasy Points',
                template='plotly_dark')

fig.show()

图6 各位置下的球员平均虚拟得分


四 影响球员得分因素

在篮球中,球员通过出手投篮或罚篮获得个人得分。在比赛中,球员的得分可能受到其它因素的影响,如该球员的出场时间、抢到的篮板数、贡献的助攻数等,篮板数与助攻数可能无法从现象上反映对于球员得分的贡献,但纵观全局,所引发的蝴蝶效应是巨大的。

4.1 出场时间

fig= go.Figure()

fig.add_trace(go.Scatter(x=df['Minutes_Played'],
                         y=df['Total_Points'],
                        mode='markers', 
                        marker_color='yellow',
                        opacity=0.7))

fig.update_layout(title='Points vs. Minutes Played',
                xaxis_title='Minutes Played',
                yaxis_title='Total Points',
                template='plotly_dark')

fig.show()

图7 出场时间与总得分散点图

4.2 总篮板数

fig_scatter = px.scatter(df, x='Total_Rebounds', y='Total_Points',
                         title='Total Points vs Total Rebounds',
                         labels={'Total_Rebounds': 'Total Rebounds',
                                 'Total_Points': 'Total Points'},
                         template='plotly_dark',
                         color_discrete_sequence=['orange'])
fig_scatter.show()

图8 总篮板数与总得分散点图

4.3 总助攻数

fig_scatter = px.scatter(df, x='Assists', y='Total_Points',
                         title='Total Points vs Assists',
                         labels={'Assists': 'Assists', 
                                 'Total_Points': 'Total Points'},
                         template='plotly_dark',
                         color_discrete_sequence=['cyan'])
fig_scatter.show()

图9 总助攻数与总得分散点图

五 球员分析

5.1 球员对比分析

在NBA体育环境中,詹姆斯、杜兰特、库里等超级巨星拥有着大量球迷和广泛的商业价值,球迷与行业专家对于他们的赛场表现也相当关注,该节通过对比詹姆斯、杜兰特、库里、扬尼斯以及伦纳德五名超级巨星的各项数据,观察他们2023年的登场表现并做出评价。

球员对比分析步骤:

  • 选取球员并从数据集中抓取对应球员信息
  • 选定指标并设置最大值
  • 创建雷达图
  • 计算球员各项指标归一化的数据
  • 填充组件展示图像
radar_columns = ['Total_Points', 
                 'Total_Rebounds', 
                 'Assists', 'Steals', 
                 'Blocks', 'Minutes_Played']

selected_players_names = ['LeBron James', 
                          'Kevin Durant', 
                          'Stephen Curry', 
                          'Giannis Antetokounmpo', 
                          'Kawhi Leonard']
selected_players = df[df['Player_Name'].isin(selected_players_names)]
# 每个指标的最大值
max_values = {
    'Total_Points': 2225,
    'Total_Rebounds': 973,
    'Assists': 741,
    'Steals': 128,
    'Blocks': 193,
    'Minutes_Played': 2963
}

# 创建雷达图
fig_radar = go.Figure()

for index, player in selected_players.iterrows():
    # 计算每个球员在每个指标上的 r 值,使用指标的最大值来归一化
    r_values = [player[column] / max_values[column] for 
                column in radar_columns]
    fig_radar.add_trace(go.Scatterpolar(
        r=r_values,
        theta=radar_columns,
        fill='toself',
        name=player['Player_Name']
    ))
fig_radar.update_layout(
    title='Player Comparison - Overall Performance',
    template='plotly_dark',
    polar=dict(
        radialaxis=dict(visible=True, range=[0, 1]),
    ),
)
fig_radar.show()

图10 詹姆斯 图11 杜兰特
图12 库里 图13 扬尼斯
图14 伦纳德 图15 全明星
通过球员对比分析雷达图发现,司职大前锋的扬尼斯在总篮板数、总得分数以及出场时间上均位于五名球员之首,詹姆斯、杜兰特与伦纳德分别在助攻数、盖帽数与抢断数上位列首位,司职控球后卫、以三分见长的库里则在总得分上仅次于扬尼斯、位列第二名。
单从球员个人数据方面分析,2023年的扬尼斯可谓是做到了真正的攻防一体,作为雄鹿队的球队核心,扬尼斯在2023年常规赛的表现十分亮眼,无愧为一名超级巨星。

5.2 最佳防守球员DPOY

在NBA中,衡量一名球员的水平不仅仅参考其进攻能力,防守能力也同样重要。一名出色的NBA球员应该具备强大的意志品质,在进攻端高效发挥自己的得分能力,在防守端尽职尽责,尽可能不让对手得分,做到攻防一体,这才是巨星的衡量标准。

df['Defensive_Performance'] = df['Blocks'] + df['Steals'] + 
                              df['Defensive_Rebounds']/10
players = df.sort_values(by='Defensive_Performance', 
                         ascending=False).head(10)

fig_defending = go.Figure()
fig_defending.add_trace(go.Bar(x=players['Player_Name'],
                               y=players['Defensive_Performance'],
                               marker_color='white'))

fig_defending.update_layout(
    title='Top 10 Best Defending Players',
    xaxis_title='Player Name',
    yaxis_title='Defensive Performance',
    template='plotly_dark'
)
fig_defending.show()

图16 十大防守球员
在给定的衡量指标下,得到如上图的十大防守球员,其中2023赛季的mvp乔尔-恩比德位列榜中,除此之外,安东尼-戴维斯、鲁迪-戈贝尔以及小将爱德华兹也都入选了十大防守球员名单。2023赛季NBA官方评定的最佳防守球员为小贾伦-杰克逊,位于上图第二位。

5.3 最佳进攻球员

在进攻端,主要考虑球员的得分、进攻篮板(用于二次进攻)以及助攻数据,一名优秀的进攻球员,不仅具备自己单打得分的绝对实力,还应具备团队组织能力。考虑到中锋、大前锋在得分、助攻上的数据不比后卫,因此加入进攻篮板数据,用于更加均衡地评价全位置球员的进攻能力。

df['Attack_Performance'] = df['Total_Points'] +
                        df['Offensive_Rebounds']/10 + 
                        df['Assists'] * 2
players = df.sort_values(by='Attack_Performance', 
                                        ascending=False).head(10)

fig = go.Figure()
fig.add_trace(go.Bar(x=players['Player_Name'],
                    y=players['Attack_Performance'],
                    marker_color='white'))

fig.update_layout(
    title='Top 10 Best Attacking Players',
    xaxis_title='Player Name',
    yaxis_title='Attack Performance',
    template='plotly_dark'
)
fig.show()

图17 十大进攻球员
在给定的衡量指标下,特雷杨登上进攻球员榜首,他优秀的传控能力使得他在球场上能够完美地掌控球队,东契奇、约基奇也同样作为各自球队的核心登上进攻球员榜单,作为联盟中少有的两名欧洲球员,东契奇和约基奇经常能在球场上贡献大三双的表现,入选进攻球员榜单当之无愧。注意到,2023赛季mvp乔尔-恩比德同样也出现在榜单中,他也是唯一一位同时出现在十大防守球员榜单和十大进攻球员榜单中的角色,常规赛mvp当之无愧。

六 球队分析

6.1 球队球员数

fig = px.histogram(df, x='Team_Abbreviation', 
                   color_discrete_sequence=px.colors.qualitative.Vivid,
             title='Players teams counts', template='plotly_dark')
fig.show()

图18 球队球员数

6.2 球队胜负场

上述分析所依赖的的数据集为所有球员个人的数据,但是并没有包含球队整体的数据。接下来给出的数据为2023年球队的数据,用于对球队进行整体性的分析。

NBA2022-2023(1).xlsx
根据球队整体的数据,统计出全部球队在当年的胜负场数并利用excel绘制图表。从图表所反映的信息可知,2023年胜场数最高的球队为扬尼斯领衔的密尔沃基雄鹿队,有趣的是,在季后赛的第一轮,东部第一雄鹿队惨遭东部第八热火队爆冷淘汰。
图19 球队胜负场

6.3 球队场均得分

图20 球队场均得分
在对球队胜负场进行统计分析后,统计球队场均得分并绘制图表,发现场均得分最高的为西部的国王队,各球队之间的场均得分差距并不大。因此,在比赛中做好防守、控制失误、尽量少让对手得分对于赢得比赛显得至关重要。

七 变量分析

7.1 箱线图

回到原来的球员个人数据集,按照单一变量划分数据,绘制箱线图反映各变量下数据分布的中心位置和散布范围。

column_to_exclude = ['Player_Name', 'Position', 'Team_Abbreviation']

all_columns = df.columns

columns = all_columns.drop(column_to_exclude).values

num_columns = len(columns)
num_rows = (num_columns + 1) // 2

fig, axes = plt.subplots(num_rows, 2, figsize=(10, 40),dpi = 300)

colors = sns.color_palette("Set3", num_columns)
for i, column in enumerate(columns):
    row = i // 2
    col = i % 2
    sns.boxplot(data=df[column], ax=axes[row, col], color=colors[i])
    axes[row, col].set_title(f'Box Plot of {column}')

if num_columns % 2 != 0:
    axes[-1, -1].axis('off')

plt.tight_layout()
plt.show()

图21 箱线图

7.2 相关性

在考虑各变量间的相关性时,需要剔除各变量下数据的异常值,这里考虑剔除三分出手数低于10次、罚球次数少于10次的球员数据,用于保证数据的质量和准确性,以便进行进一步的分析和建模。

outliers_condition = ((df['Three_Point_FG_Attempted'] < 10) |
                      (df['Free_Throws_Attempted'] < 10))

df = df[~outliers_condition]
correlation_matrix = df.iloc[:, 3:].corr()

fig = go.Figure(data=go.Heatmap(
    z=correlation_matrix.values,
    x=correlation_matrix.columns,
    y=correlation_matrix.index,
    colorscale='Oranges',
))

fig.update_layout(
    title='Correlation Heatmap',
    xaxis_title='Features',
    yaxis_title='Features',
    height=1000,
    template='plotly_dark'
)
fig.show()

图22 相关性热力图

八 建模预处理

8.1 变量处理

根据7.2相关性分析中热力图所反映的各变量间的相关系数,考虑到多重共线性,剔除高相关性的变量,旨在潜在地提高后续模型的性能。

df.drop(columns=['Player_Name', 'Position', 'Team_Abbreviation', 
                 'Field_Goals_Made', 'Field_Goals_Attempted',
                 'Three_Point_FG_Made', 'Three_Point_FG_Attempted',
                 'Three_Point_FG_Percentage',
                 'NBA_Fantasy_Points', 'Double_Doubles', 
                 'Free_Throws_Attempted',
                 'Triple_Doubles', 'Offensive_Rebounds', 
                 'Defensive_Rebounds'],
        inplace=True)

8.2 训练集与测试集

在进行模型的建立与计算之前,需要先划分训练集和测试集。

X = df.drop('Total_Points',axis = 1)
y = df['Total_Points']

X_train, X_test, y_train, y_test = train_test_split(X, y, 
                                                    test_size=0.2,
                                                    random_state=43)
其中自变量为除Total_Points以外的剩余全部变量,因变量为Total_Points,切分20%的数据集为测试集,设置随机采样种子,获得数据的训练集和测试集。

九 模型预测

9.1 线性回归

LRmodel = LinearRegression(fit_intercept=True)
LRmodel.fit(X_train, y_train)
y_pred = LRmodel.predict(X_test)
print(r2_score(y_test, y_pred))
0.9508670290220914
# 绘制预测值与实际值的散点图
plt.scatter(y_test, y_pred, color='blue', label='Actual vs Predicted')
plt.title('Actual vs Predicted')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.plot(y_test,y_test,color='red',label='y=x')
plt.legend()
plt.show()

图23 线性回归预测

9.2 KNN回归

KNNmodel = KNeighborsRegressor()

param_grid = {
    'n_neighbors': range(3, 11, 1),
    'weights': ['uniform', 'distance'],
    'p': [1, 2]
}

grid_search = GridSearchCV(estimator=KNNmodel, param_grid=param_grid,
                           scoring='r2', cv=5)
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print(f"Best hyperparameters: {best_params}")

best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)
r2 = r2_score(y_test, y_pred)
print(f"R2 score: {r2}")
其中,参数网络param_grid中的参数'n_neighbors'为KNN回归时考虑的邻居个数,'weights'表示邻居权重系数,'p'是明可夫斯基距离参数,当p取1时,使用曼哈顿距离,当p取2时,使用欧氏距离。设定参数网络后,创建GridSearchCV对象,在训练集上拟合得到最优参数并输出,在测试集上预测,输出可决系数。
Best hyperparameters: {'n_neighbors': 8, 'p': 1, 'weights': 'distance'}
R2 score: 0.8894559661663552
# 绘制预测值与实际值的散点图
plt.scatter(y_test, y_pred, color='blue', label='Actual vs Predicted')
plt.title('Actual vs Predicted')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.plot(y_test,y_test,color='red',label='y=x')
plt.legend()
plt.show()

图24 KNN回归预测

9.3 决策树回归

param_grid = {
    'max_depth': [5, 10, 15, 20],
    'min_samples_split': [15, 20, 25, 30, 35],
    'min_samples_leaf': [1, 2, 4, 6, 8, 10]
}
DTRmodel = DecisionTreeRegressor()
grid_search = GridSearchCV(estimator=DTRmodel, param_grid=param_grid,
                           scoring='r2', cv=5)
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print(f"Best hyperparameters: {best_params}")

best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)
test_r2 = r2_score(y_test, y_pred)
print(f"R2 score on test set: {test_r2}")

其中,参数网络中限制了树的最大深度、叶子结点的最小样本数目,目的是为了防止过拟合,提高模型的泛化能力。

Best hyperparameters: {'max_depth': 10, 'min_samples_leaf': 2, 
                                        'min_samples_split': 15}
R2 score on test set: 0.9142067316656353
# 绘制预测值与实际值的散点图
plt.scatter(y_test, y_pred, color='blue', label='Actual vs Predicted')
plt.title('Actual vs Predicted')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.plot(y_test,y_test,color='red',label='y=x')
plt.legend()
plt.show()

图25 决策树回归预测

9.4 随机森林回归

RFRmodel = RandomForestRegressor()

param_grid = {
    'n_estimators': range(80,160,10),
    'max_depth': range(1,20,2),
}

grid_search = GridSearchCV(estimator=RFRmodel, param_grid=param_grid,
                           scoring='r2', cv=5)
grid_search.fit(X_train, y_train)

best_params = grid_search.best_params_
print(f"Best hyperparameters: {best_params}")

best_RFRmodel = grid_search.best_estimator_
y_pred = best_RFRmodel.predict(X_test)
test_r2 = r2_score(y_test, y_pred)
print(f"R2 score on test set: {test_r2}")
Best R2 score: 0.9307685638961534
Best hyperparameters: {'max_depth': 13, 'n_estimators': 90}
R2 score on test set: 0.9550547280293766
# 绘制预测值与实际值的散点图
plt.scatter(y_test, y_pred, color='blue', label='Actual vs Predicted')
plt.title('Actual vs Predicted')
plt.xlabel('Actual Values')
plt.ylabel('Predicted Values')
plt.plot(y_test,y_test,color='red',label='y=x')
plt.legend()
plt.show()

图26 随机森林回归

9.5 模型预测解释

对比线性回归、KNN回归、决策树回归以及随机森林回归模型的预测结果,发现线性回归与随机森林回归的预测结果相对较好,KNN的预测结果相对较差。同时,对于每个模型的预测结果而言,随着球员总得分的不断上升,预测的偏差也会随之增大。从结果上而言,随机森林的预测结果最优,它通过集成多棵决策树以实现最优解,但是,它在运算时间上消耗最大。

结语

本篇报告对于2023赛季的NBA球员个人表现以及球队整体表现进行了数据分析,包括对于数据集的探索以发现变量间可能具有的相关关系、对数据进行可视化展示、分析影响球员个人得分的因素、对比球员数据评价球员表现、分析球队赛季的整体表现,最后,使用机器学习技术对球员个人得分进行预测并得出结论。

  • 44
    点赞
  • 60
    收藏
    觉得还不错? 一键收藏
  • 9
    评论
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值