World Cup Prediction

Project Introduction

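The 2022 World Cup in Qatar (FIFA World Cup Qatar 2022) was the 22nd FIFA World Cup. It was held from November 20, 2022 local time (November 21 Beijing time) to December 18 in 8 stadiums across 5 cities in Qatar, with the schedule shortened from the original 32 days to 29. Qatar is the third Asian country to host the World Cup, after Japan and South Korea, the first Islamic country to host it, and the first host since World War II that had never before qualified for the World Cup finals. The tournament's total cost reportedly reached 229 billion US dollars, earning it the label "the most expensive World Cup in history".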

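The project uses two historical datasets, FIFA World Ranking 1992-2022 and International Football Results from 1872 to 2022, to predict World Cup matches.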

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import joblib
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
from sklearn.metrics import accuracy_score, roc_curve, roc_auc_score, confusion_matrix

Data Processing

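results.csv contains the following columns:

  • date - date of the match
  • home_team - name of the home team
  • away_team - name of the away team
  • home_score - full-time home team score, including extra time but excluding penalty shootouts
  • away_score - full-time away team score, including extra time but excluding penalty shootouts
  • tournament - name of the tournament
  • city - name of the city/town/administrative unit where the match was played
  • country - name of the country where the match was played
  • neutral - TRUE/FALSE column indicating whether the match was played at a neutral venue
# Load the data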
results = pd.read_csv('/home/aistudio/work/results.csv', parse_dates=['date'])
results.head()
         date  home_team  away_team  home_score  away_score  tournament     city   country  neutral
0  1872-11-30   Scotland    England           0           0    Friendly  Glasgow  Scotland    False
1  1873-03-08    England   Scotland           4           2    Friendly   London   England    False
2  1874-03-07   Scotland    England           2           1    Friendly  Glasgow  Scotland    False
3  1875-03-06    England   Scotland           2           2    Friendly   London   England    False
4  1876-03-04   Scotland    England           3           0    Friendly  Glasgow  Scotland    False
# Inspect the data
results.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 44289 entries, 0 to 44288
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype         
---  ------      --------------  -----         
 0   date        44289 non-null  datetime64[ns]
 1   home_team   44289 non-null  object        
 2   away_team   44289 non-null  object        
 3   home_score  44289 non-null  int64         
 4   away_score  44289 non-null  int64         
 5   tournament  44289 non-null  object        
 6   city        44289 non-null  object        
 7   country     44289 non-null  object        
 8   neutral     44289 non-null  bool          
dtypes: bool(1), datetime64[ns](1), int64(2), object(5)
memory usage: 2.7+ MB
# Check for missing values
results.isna().sum()
date          0
home_team     0
away_team     0
home_score    0
away_score    0
tournament    0
city          0
country       0
neutral       0
dtype: int64
# Keep World Cup qualifiers and World Cup finals matches played on or after 1992-12-31
fifa_data = results[(results['date'] >= '1992-12-31') & ((results['tournament'] == 'FIFA World Cup') | (results['tournament'] == 'FIFA World Cup qualification'))]
fifa_data = fifa_data.drop(['tournament'], axis=1)
fifa_data = fifa_data.reset_index(drop=True)
fifa_data.head()
         date     home_team  away_team  home_score  away_score          city       country  neutral
0  1993-01-10        Angola   Zimbabwe           1           1        Luanda        Angola    False
1  1993-01-10      DR Congo   Cameroon           1           2      Kinshasa         Zaïre    False
2  1993-01-16  South Africa    Nigeria           0           0  Johannesburg  South Africa    False
3  1993-01-16      Tanzania     Zambia           1           3        Mwanza      Tanzania    False
4  1993-01-17         Benin    Tunisia           0           5       Cotonou         Benin    False
# Inspect the data
fifa_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6359 entries, 0 to 6358
Data columns (total 8 columns):
 #   Column      Non-Null Count  Dtype         
---  ------      --------------  -----         
 0   date        6359 non-null   datetime64[ns]
 1   home_team   6359 non-null   object        
 2   away_team   6359 non-null   object        
 3   home_score  6359 non-null   int64         
 4   away_score  6359 non-null   int64         
 5   city        6359 non-null   object        
 6   country     6359 non-null   object        
 7   neutral     6359 non-null   bool          
dtypes: bool(1), datetime64[ns](1), int64(2), object(4)
memory usage: 354.1+ KB
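
The tournament filter above can also be written with `isin`, which may be easier to extend if more competitions are added later. A minimal sketch on an invented three-row frame (the toy rows are assumptions for illustration, not rows from results.csv):

```python
import pandas as pd

# Toy stand-in for `results`; these rows are made up for illustration.
toy = pd.DataFrame({
    'date': pd.to_datetime(['1990-06-08', '1994-06-17', '1996-06-08']),
    'tournament': ['FIFA World Cup', 'FIFA World Cup', 'UEFA Euro'],
})

wanted = ['FIFA World Cup', 'FIFA World Cup qualification']
mask = (toy['date'] >= '1992-12-31') & toy['tournament'].isin(wanted)

# Only the 1994 World Cup row satisfies both the date and tournament conditions.
filtered = toy[mask].reset_index(drop=True)
print(filtered)
```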

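fifa_ranking.csv contains the following columns:

  • rank - current rank of the country/region
  • country_full - full country name
  • country_abrv - country abbreviation
  • total_points - current total points
  • previous_points - total points at the previous rating
  • rank_change - change in rank since the previous release
  • confederation - FIFA confederation
  • rank_date - date the ranking was calculated
# Load the data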
fifa_ranking = pd.read_csv('/home/aistudio/work/fifa_ranking.csv', parse_dates=['rank_date'])
fifa_ranking.head()
   rank    country_full  country_abrv  total_points  previous_points  rank_change  confederation   rank_date
0    74      Madagascar           MAD          18.0              0.0            0            CAF  1992-12-31
1    52           Qatar           QAT          27.0              0.0            0            AFC  1992-12-31
2    51         Senegal           SEN          27.0              0.0            0            CAF  1992-12-31
3    50     El Salvador           SLV          28.0              0.0            0       CONCACAF  1992-12-31
4    49  Korea Republic           KOR          28.0              0.0            0            AFC  1992-12-31
# Harmonize country names: some full names differ between fifa_ranking and results
# Note: str.replace is case-sensitive, so the lowercase pattern 'chinese taipei' may not match any row
fifa_ranking['country_full'] = (
    fifa_ranking['country_full']
    .str.replace('Brunei Darussalam', 'Brunei')
    .str.replace('Cape Verde Islands', 'Cape Verde')
    .str.replace('chinese taipei', 'taiwan')
    .str.replace('Congo DR', 'DR Congo')
    .str.replace("Côte d'Ivoire", 'Ivory Coast')
    .str.replace('Curacao', 'Curaçao')
    .str.replace('IR Iran', 'Iran')
    .str.replace('Kyrgyz Republic', 'Kyrgyzstan')
    .str.replace('Korea DPR', 'North Korea')
    .str.replace('Korea Republic', 'South Korea')
    .str.replace('St Kitts and Nevis', 'Saint Kitts and Nevis')
    .str.replace('St Lucia', 'Saint Lucia')
    .str.replace('St Vincent and the Grenadines', 'Saint Vincent and the Grenadines')
    .str.replace('Sao Tome e Principe', 'São Tomé and Príncipe')
    .str.replace('US Virgin Islands', 'United States Virgin Islands')
    .str.replace('USA', 'United States')
)
# Index fifa_ranking by date, group by country, resample to daily frequency with forward fill, then reset the index
fifa_ranking = fifa_ranking.set_index(['rank_date']).groupby(['country_full'], group_keys=False).resample('D').fillna(method='ffill').reset_index()
fifa_ranking.head()
    rank_date  rank  country_full  country_abrv  total_points  previous_points  rank_change  confederation
0  2003-01-15   204   Afghanistan           AFG           7.0              0.0            0            AFC
1  2003-01-16   204   Afghanistan           AFG           7.0              0.0            0            AFC
2  2003-01-17   204   Afghanistan           AFG           7.0              0.0            0            AFC
3  2003-01-18   204   Afghanistan           AFG           7.0              0.0            0            AFC
4  2003-01-19   204   Afghanistan           AFG           7.0              0.0            0            AFC
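
Rankings are published only on certain dates, so the daily resample with forward fill gives every calendar day the most recent published ranking, which is what lets a match date be joined to a ranking row. A minimal sketch of that idea on invented data (the countries `A` and `B` and their ranks are assumptions; the per-group loop is a simplified stand-in for the one-line `groupby(...).resample(...)` call above):

```python
import pandas as pd

# Toy ranking releases: A is published on Jan 1 and Jan 4, B on Jan 2 and Jan 4.
ranking = pd.DataFrame({
    'rank_date': pd.to_datetime(['2022-01-01', '2022-01-04',
                                 '2022-01-02', '2022-01-04']),
    'country_full': ['A', 'A', 'B', 'B'],
    'rank': [1, 2, 5, 4],
})

daily_parts = []
for country, grp in ranking.groupby('country_full'):
    # Upsample this country's releases to one row per day, carrying the
    # last published values forward until the next release.
    daily = grp.set_index('rank_date').resample('D').ffill()
    daily_parts.append(daily)

daily_ranking = pd.concat(daily_parts).reset_index()
print(daily_ranking)
```

Country `A` keeps rank 1 on January 2 and 3 because no newer release exists until January 4.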
# Merge: join results with fifa_ranking on match date and team name (home side first, then away side)
fifa_data = fifa_data.merge(fifa_ranking[['country_full', 'total_points', 'previous_points', 'rank', 'rank_change', 'rank_date']], left_on=['date', 'home_team'], right_on=['rank_date', 'country_full']).drop(['rank_date', 'country_full'], axis=1)
fifa_data = fifa_data.merge(fifa_ranking[['country_full', 'total_points', 'previous_points', 'rank', 'rank_change', 'rank_date']], left_on=['date', 'away_team'], right_on=['rank_date', 'country_full'], suffixes=('_home', '_away')).drop(['rank_date', 'country_full'], axis=1)
fifa_data.head()
         date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  rank_home  rank_change_home  total_points_away  previous_points_away  rank_away  rank_change_away
0  1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0        102                 0               27.0                   0.0         54                 0
1  1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0        124                 0               50.0                   0.0         13                 0
2  1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0         80                 0               38.0                   0.0         32                 0
3  1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0        127                 0               35.0                   0.0         38                 0
4  1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0        139                 0               41.0                   0.0         27                 0
# Inspect the data
fifa_data.info()
<class 'pandas.core.frame.DataFrame'>
Int64Index: 6052 entries, 0 to 6051
Data columns (total 16 columns):
 #   Column                Non-Null Count  Dtype         
---  ------                --------------  -----         
 0   date                  6052 non-null   datetime64[ns]
 1   home_team             6052 non-null   object        
 2   away_team             6052 non-null   object        
 3   home_score            6052 non-null   int64         
 4   away_score            6052 non-null   int64         
 5   city                  6052 non-null   object        
 6   country               6052 non-null   object        
 7   neutral               6052 non-null   bool          
 8   total_points_home     6052 non-null   float64       
 9   previous_points_home  6052 non-null   float64       
 10  rank_home             6052 non-null   int64         
 11  rank_change_home      6052 non-null   int64         
 12  total_points_away     6052 non-null   float64       
 13  previous_points_away  6052 non-null   float64       
 14  rank_away             6052 non-null   int64         
 15  rank_change_away      6052 non-null   int64         
dtypes: bool(1), datetime64[ns](1), float64(4), int64(6), object(4)
memory usage: 762.4+ KB
# Check for missing values
fifa_data.isna().sum()
date                    0
home_team               0
away_team               0
home_score              0
away_score              0
city                    0
country                 0
neutral                 0
total_points_home       0
previous_points_home    0
rank_home               0
rank_change_home        0
total_points_away       0
previous_points_away    0
rank_away               0
rank_change_away        0
dtype: int64
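
The two inner merges shrink the table from 6359 to 6052 rows, which suggests that some team/date pairs have no ranking entry, for example because a team name still differs between the two files. One way to list such names is a left merge with `indicator=True`. The sketch below uses invented rows and ranks purely for illustration:

```python
import pandas as pd

# Toy match and ranking tables; names, dates, and ranks are made up.
matches = pd.DataFrame({
    'date': pd.to_datetime(['2022-01-01', '2022-01-01']),
    'home_team': ['South Korea', 'Taiwan'],
})
ranking = pd.DataFrame({
    'rank_date': pd.to_datetime(['2022-01-01', '2022-01-01']),
    'country_full': ['South Korea', 'Chinese Taipei'],
    'rank': [28, 150],
})

# A left merge keeps every match; `_merge` records whether a ranking row was found.
checked = matches.merge(
    ranking, how='left',
    left_on=['date', 'home_team'], right_on=['rank_date', 'country_full'],
    indicator=True,
)
unmatched = checked.loc[checked['_merge'] == 'left_only', 'home_team'].unique()
print(unmatched)
```

Any name that shows up in `unmatched` is a candidate for the name-harmonization step.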

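Feature Engineering

  • result - match outcome; 0: home win, 1: away win, 2: draw
  • home_points - points earned by the home team; 3: home win, 0: away win, 1: draw
  • away_points - points earned by the away team; 3: away win, 0: home win, 1: draw
  • target - prediction target; 0: home win, 1: away win or draw
# Feature engineering: derive result, home_points, away_points, and target from the score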
def get_result(home_score, away_score):
    if home_score > away_score:
        return pd.Series([0, 3, 0, 0])
    elif home_score < away_score:
        return pd.Series([1, 0, 3, 1])
    else:
        return pd.Series([2, 1, 1, 1])

results = fifa_data.apply(lambda x: get_result(x['home_score'], x['away_score']), axis=1)
fifa_data[['result', 'home_points', 'away_points', 'target']] = results
fifa_data.head()
         date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  rank_home  rank_change_home  total_points_away  previous_points_away  rank_away  rank_change_away  result  home_points  away_points  target
0  1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0        102                 0               27.0                   0.0         54                 0       2            1            1       1
1  1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0        124                 0               50.0                   0.0         13                 0       2            1            1       1
2  1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0         80                 0               38.0                   0.0         32                 0       1            0            3       1
3  1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0        127                 0               35.0                   0.0         38                 0       1            0            3       1
4  1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0        139                 0               41.0                   0.0         27                 0       2            1            1       1
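
The row-wise `apply` above works, but the same four columns can be derived with whole-column comparisons, which is usually faster on larger tables. A sketch of that alternative using `numpy.select` on an invented three-match frame (this is a swapped-in technique, not the notebook's own code):

```python
import numpy as np
import pandas as pd

# Toy scores: a home win, an away win, and a draw.
toy = pd.DataFrame({'home_score': [2, 0, 1], 'away_score': [1, 3, 1]})

home_win = toy['home_score'] > toy['away_score']
away_win = toy['home_score'] < toy['away_score']

# Same encoding as get_result: the default branch covers draws.
toy['result'] = np.select([home_win, away_win], [0, 1], default=2)
toy['home_points'] = np.select([home_win, away_win], [3, 0], default=1)
toy['away_points'] = np.select([home_win, away_win], [0, 3], default=1)
toy['target'] = (~home_win).astype(int)  # 0: home win, 1: away win or draw
print(toy)
```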
# Label-encode the date and the categorical columns
label_encoder = LabelEncoder()
labels = ['date', 'home_team', 'away_team', 'city', 'country']
for label in labels:
    fifa_data[f'{label}_encoding'] = label_encoder.fit_transform(fifa_data[label])
fifa_data.head()
         date     home_team    away_team  home_score  away_score          city       country  neutral  total_points_home  previous_points_home  ...  rank_change_away  result  home_points  away_points  target  date_encoding  home_team_encoding  away_team_encoding  city_encoding  country_encoding
0  1993-01-10        Angola     Zimbabwe           1           1        Luanda        Angola    False               10.0                   0.0  ...                 0       2            1            1       1              0                   5                 206            345                 4
1  1993-01-16  South Africa      Nigeria           0           0  Johannesburg  South Africa    False                5.0                   0.0  ...                 0       2            1            1       1              1                 170                 136            274               171
2  1993-01-16      Tanzania       Zambia           1           3        Mwanza      Tanzania    False               15.0                   0.0  ...                 0       1            0            3       1              1                 183                 205            410               184
3  1993-01-17         Benin      Tunisia           0           5       Cotonou         Benin    False                4.0                   0.0  ...                 0       1            0            3       1              2                  21                 189            157                20
4  1993-01-17      Botswana  Ivory Coast           0           0      Gaborone      Botswana    False                2.0                   0.0  ...                 0       2            1            1       1              2                  26                  94            216                25

5 rows × 25 columns

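# Plot the correlation heatmap (numeric and boolean columns only; newer pandas versions no longer drop text columns silently)
plt.figure(figsize=(16, 16))
sns.heatmap(fifa_data.select_dtypes(include=['number', 'bool']).corr(), annot=True, linewidths=0.2, square=True)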
plt.show()

[Figure: correlation heatmap of fifa_data including the encoded columns (main_files/main_20_0.png)]

# Drop the encoded columns: their correlation is too low
fifa_data = fifa_data.drop(['date_encoding', 'home_team_encoding', 'away_team_encoding', 'city_encoding', 'country_encoding'], axis=1)
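
A heatmap with this many columns can be hard to read. When the question is only which features move with the label, it may help to pull out the `target` column of the correlation matrix and sort it by absolute value. A small sketch on invented data (the column `noise` and all values are assumptions for illustration):

```python
import pandas as pd

# Toy frame: rank_diff tracks the target, noise does not.
toy = pd.DataFrame({
    'home_team': ['A', 'B', 'C', 'D'],
    'rank_diff': [-30, -10, 15, 40],
    'noise': [1, -1, -1, 1],
    'target': [0, 0, 1, 1],
})

# Keep numeric/boolean columns, then read off each feature's correlation with target.
numeric = toy.select_dtypes(include=['number', 'bool'])
target_corr = numeric.corr()['target'].drop('target')
ranked = target_corr.abs().sort_values(ascending=False)
print(ranked)
```

Pearson correlation only captures linear association, so a low value for an arbitrary integer code does not by itself show that the underlying category is uninformative.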

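Feature Engineering

  • rank_diff - difference in rank (home minus away)
  • rank_change_diff - difference in rank change
  • total_points_diff - difference in total points
  • previous_points_diff - difference in total points at the previous rating
  • home_points2rank - home team's match points / away team's rank
  • away_points2rank - away team's match points / home team's rank
  • points2rank_diff - difference between the two points2rank values
# Feature engineering: home-minus-away differences and points2rank ratios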
fifa_data['rank_diff'] = fifa_data['rank_home'] - fifa_data['rank_away']
fifa_data['rank_change_diff'] = fifa_data['rank_change_home'] - fifa_data['rank_change_away']
fifa_data['total_points_diff'] = fifa_data['total_points_home'] - fifa_data['total_points_away']
fifa_data['previous_points_diff'] = fifa_data['previous_points_home'] - fifa_data['previous_points_away']
fifa_data['home_points2rank'] = fifa_data['home_points'] / fifa_data['rank_away']
fifa_data['away_points2rank'] = fifa_data['away_points'] / fifa_data['rank_home']
fifa_data['points2rank_diff'] = fifa_data['home_points2rank'] - fifa_data['away_points2rank']
# Plot the correlation heatmap (numeric and boolean columns only)
plt.figure(figsize=(16, 16))
sns.heatmap(fifa_data.select_dtypes(include=['number', 'bool']).corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of fifa_data with the difference features (main_files/main_24_0.png)]
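
As a worked example of the points2rank features, the arithmetic for the first row of the merged table (Angola vs Zimbabwe, a 1-1 draw with ranks 102 and 54) is reproduced below. Because `home_points` and `away_points` come from the result of that same match, these ratios are only known once the match has been played.

```python
# Values from the first row of the merged table (Angola vs Zimbabwe, 1-1).
home_points, away_points = 1, 1   # a draw gives each side 1 point
rank_home, rank_away = 102, 54

home_points2rank = home_points / rank_away   # 1 / 54
away_points2rank = away_points / rank_home   # 1 / 102
points2rank_diff = home_points2rank - away_points2rank

print(round(home_points2rank, 6))   # 0.018519
print(round(away_points2rank, 6))   # 0.009804
print(round(points2rank_diff, 6))   # 0.008715
```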

# Split: build a per-team view of each match (the team's own values plus its rival's)
home_team = fifa_data[['date', 'home_team', 'home_score', 'away_score', 'total_points_home', 'total_points_away', 'previous_points_home', 'previous_points_away', 'rank_home', 'rank_away', 'home_points', 'away_points', 'home_points2rank', 'away_points2rank', 'result']]
away_team = fifa_data[['date', 'away_team', 'away_score', 'home_score', 'total_points_away', 'total_points_home', 'previous_points_away', 'previous_points_home', 'rank_away', 'rank_home', 'away_points', 'home_points', 'away_points2rank', 'home_points2rank', 'result']]
home_team.columns = [h.replace('home_', '').replace('_home', '').replace('away_', 'rival_').replace('_away', '_rival') for h in home_team.columns]
away_team.columns = [a.replace('away_', '').replace('_away', '').replace('home_', 'rival_').replace('_home', '_rival') for a in away_team.columns]
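# Merge: stack the home and away views into one team-level table
# (DataFrame.append was removed in pandas 2.0; pd.concat gives the same result)
team_data = pd.concat([home_team, away_team])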
data_copy = team_data.copy()
team_data.head()
         date          team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  points  rival_points  points2rank  rival_points2rank  result
0  1993-01-10        Angola      1            1          10.0                27.0              0.0                    0.0   102          54       1             1     0.018519           0.009804       2
1  1993-01-16  South Africa      0            0           5.0                50.0              0.0                    0.0   124          13       1             1     0.076923           0.008065       2
2  1993-01-16      Tanzania      1            3          15.0                38.0              0.0                    0.0    80          32       0             3     0.000000           0.037500       1
3  1993-01-17         Benin      0            5           4.0                35.0              0.0                    0.0   127          38       0             3     0.000000           0.023622       1
4  1993-01-17      Botswana      0            0           2.0                41.0              0.0                    0.0   139          27       1             1     0.037037           0.007194       2
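
The reshaping above turns each match row into two team rows, one seen from the home side and one from the away side, so that a team's history can be collected regardless of where it played. A compact sketch of the same idea using explicit `rename` mappings instead of string replacement on column names (toy matches and teams are invented):

```python
import pandas as pd

# Toy match table with two matches.
matches = pd.DataFrame({
    'date': pd.to_datetime(['2022-01-01', '2022-01-02']),
    'home_team': ['A', 'C'],
    'away_team': ['B', 'D'],
    'home_score': [2, 0],
    'away_score': [1, 0],
})

keep = ['date', 'team', 'score', 'rival_score']
# Home perspective: the home side is "team", the away side is the rival.
home_view = matches.rename(columns={
    'home_team': 'team', 'home_score': 'score', 'away_score': 'rival_score',
})[keep]
# Away perspective: roles are swapped.
away_view = matches.rename(columns={
    'away_team': 'team', 'away_score': 'score', 'home_score': 'rival_score',
})[keep]

# Home rows first, then away rows, mirroring the stacking order used above.
team_rows = pd.concat([home_view, away_view], ignore_index=True)
print(team_rows)
```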

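Feature Engineering

  • mean_goals - average goals scored
  • mean_goals_last5 - average goals scored over the last five matches
  • rival_mean_goals - average goals scored by the opponent
  • rival_mean_goals_last5 - average goals scored by the opponent over the last five matches
  • mean_rank - average rank
  • mean_rank_last5 - average rank over the last five matches
  • rival_mean_rank - average rank of the opponent
  • rival_mean_rank_last5 - average rank of the opponent over the last five matches
  • mean_points - average match points
  • mean_points_last5 - average match points over the last five matches
  • rival_mean_points - average match points of the opponent
  • rival_mean_points_last5 - average match points of the opponent over the last five matches
  • mean_points2rank - average points2rank
  • mean_points2rank_last5 - average points2rank over the last five matches
  • rival_mean_points2rank - average points2rank of the opponent
  • rival_mean_points2rank_last5 - average points2rank of the opponent over the last five matches
# Feature engineering: for each team row, average the statistics of that team's earlier matches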
team_values = []
for idx, row in team_data.iterrows():
    team = row['team']
    date = row['date']
    pasts = team_data.loc[(team_data['team'] == team) & (team_data['date'] < date)].sort_values(by=['date'], ascending=False)
    last5 = pasts.head(5)
    mean_goals = pasts['score'].mean()
    mean_goals_last5 = last5['score'].mean()
    rival_mean_goals = pasts['rival_score'].mean()
    rival_mean_goals_last5 = last5['rival_score'].mean()
    mean_rank = pasts['rank'].mean()
    mean_rank_last5 = last5['rank'].mean()
    rival_mean_rank = pasts['rank_rival'].mean()
    rival_mean_rank_last5 = last5['rank_rival'].mean()
    mean_points = pasts['points'].mean()
    mean_points_last5 = last5['points'].mean()
    rival_mean_points = pasts['rival_points'].mean()
    rival_mean_points_last5 = last5['rival_points'].mean()
    mean_points2rank = pasts['points2rank'].mean()
    mean_points2rank_last5 = last5['points2rank'].mean()
    rival_mean_points2rank = pasts['rival_points2rank'].mean()
    rival_mean_points2rank_last5 = last5['rival_points2rank'].mean()
    team_values.append([mean_goals, mean_goals_last5, rival_mean_goals, rival_mean_goals_last5, mean_rank, mean_rank_last5, rival_mean_rank, rival_mean_rank_last5, mean_points, mean_points_last5, rival_mean_points, rival_mean_points_last5, mean_points2rank, mean_points2rank_last5, rival_mean_points2rank, rival_mean_points2rank_last5])
# Merge: attach the historical averages to the team-level table
team_columns = ['mean_goals', 'mean_goals_last5', 'rival_mean_goals', 'rival_mean_goals_last5', 'mean_rank', 'mean_rank_last5', 'rival_mean_rank', 'rival_mean_rank_last5', 'mean_points', 'mean_points_last5', 'rival_mean_points', 'rival_mean_points_last5', 'mean_points2rank', 'mean_points2rank_last5', 'rival_mean_points2rank', 'rival_mean_points2rank_last5']
team_value = pd.DataFrame(team_values, columns=team_columns)
team_data = pd.concat([team_data.reset_index(drop=True), team_value], axis=1, ignore_index=False)
team_data.head()
         date          team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  ...  rival_mean_rank  rival_mean_rank_last5  mean_points  mean_points_last5  rival_mean_points  rival_mean_points_last5  mean_points2rank  mean_points2rank_last5  rival_mean_points2rank  rival_mean_points2rank_last5
0  1993-01-10        Angola      1            1          10.0                27.0              0.0                    0.0   102          54  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
1  1993-01-16  South Africa      0            0           5.0                50.0              0.0                    0.0   124          13  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
2  1993-01-16      Tanzania      1            3          15.0                38.0              0.0                    0.0    80          32  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
3  1993-01-17         Benin      0            5           4.0                35.0              0.0                    0.0   127          38  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN
4  1993-01-17      Botswana      0            0           2.0                41.0              0.0                    0.0   139          27  ...              NaN                    NaN          NaN                NaN                NaN                      NaN               NaN                     NaN                     NaN                           NaN

5 rows × 31 columns
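
The row-by-row loop rescans the whole table for every match, which becomes slow as the table grows. A possible vectorized sketch of the same idea for a single statistic (`score`) uses `groupby` with `shift(1)` so that each row only sees earlier matches. The toy rows are invented, and unlike the loop's strict `date <` filter this version assumes a team plays at most one match per date.

```python
import pandas as pd

# Toy team-level rows: three matches for team A, two for team B.
team_rows = pd.DataFrame({
    'date': pd.to_datetime(['2022-01-01', '2022-01-05', '2022-01-09',
                            '2022-01-02', '2022-01-06']),
    'team': ['A', 'A', 'A', 'B', 'B'],
    'score': [1, 3, 2, 0, 4],
})

team_rows = team_rows.sort_values(['team', 'date']).reset_index(drop=True)
by_team = team_rows.groupby('team')['score']

# shift(1) hides the current match, so only strictly earlier matches are averaged.
team_rows['mean_goals'] = by_team.transform(
    lambda s: s.shift(1).expanding().mean()
)
team_rows['mean_goals_last5'] = by_team.transform(
    lambda s: s.shift(1).rolling(5, min_periods=1).mean()
)
print(team_rows)
```

As in the loop, a team's first match has no history and therefore gets NaN.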

# Inspect the data
team_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 12104 entries, 0 to 12103
Data columns (total 31 columns):
 #   Column                        Non-Null Count  Dtype         
---  ------                        --------------  -----         
 0   date                          12104 non-null  datetime64[ns]
 1   team                          12104 non-null  object        
 2   score                         12104 non-null  int64         
 3   rival_score                   12104 non-null  int64         
 4   total_points                  12104 non-null  float64       
 5   total_points_rival            12104 non-null  float64       
 6   previous_points               12104 non-null  float64       
 7   previous_points_rival         12104 non-null  float64       
 8   rank                          12104 non-null  int64         
 9   rank_rival                    12104 non-null  int64         
 10  points                        12104 non-null  int64         
 11  rival_points                  12104 non-null  int64         
 12  points2rank                   12104 non-null  float64       
 13  rival_points2rank             12104 non-null  float64       
 14  result                        12104 non-null  int64         
 15  mean_goals                    11897 non-null  float64       
 16  mean_goals_last5              11897 non-null  float64       
 17  rival_mean_goals              11897 non-null  float64       
 18  rival_mean_goals_last5        11897 non-null  float64       
 19  mean_rank                     11897 non-null  float64       
 20  mean_rank_last5               11897 non-null  float64       
 21  rival_mean_rank               11897 non-null  float64       
 22  rival_mean_rank_last5         11897 non-null  float64       
 23  mean_points                   11897 non-null  float64       
 24  mean_points_last5             11897 non-null  float64       
 25  rival_mean_points             11897 non-null  float64       
 26  rival_mean_points_last5       11897 non-null  float64       
 27  mean_points2rank              11897 non-null  float64       
 28  mean_points2rank_last5        11897 non-null  float64       
 29  rival_mean_points2rank        11897 non-null  float64       
 30  rival_mean_points2rank_last5  11897 non-null  float64       
dtypes: datetime64[ns](1), float64(22), int64(7), object(1)
memory usage: 2.9+ MB
# Split: the first half of team_data holds the home rows, the second half the away rows
home_team_data = team_data.iloc[:int(team_data.shape[0] / 2), :]
away_team_data = team_data.iloc[int(team_data.shape[0] / 2):, :]
away_team_data.tail()
             date         team  score  rival_score  total_points  total_points_rival  previous_points  previous_points_rival  rank  rank_rival  ...  rival_mean_rank  rival_mean_rank_last5  mean_points  mean_points_last5  rival_mean_points  rival_mean_points_last5  mean_points2rank  mean_points2rank_last5  rival_mean_points2rank  rival_mean_points2rank_last5
12099  2022-06-01      Ukraine      3            1       1535.08             1472.66          1535.08                1471.82    27          39  ...        62.433735                   59.0     1.759036                1.8           0.891566                      0.6          0.068496                0.093412                0.028319                      0.023407
12100  2022-06-05      Ukraine      0            1       1535.08             1588.08          1535.08                1578.01    27          18  ...        62.154762                   42.0     1.773810                2.2           0.880952                      0.4          0.068596                0.107183                0.027982                      0.015407
12101  2022-06-07    Australia      2            1       1462.29             1356.99          1486.86                1353.10    42          68  ...        81.504762                   65.6     1.980952                1.0           0.809524                      1.6          0.029632                0.011321                0.021702                      0.044029
12102  2022-06-13         Peru      0            0       1562.32             1462.29          1563.45                1486.86    22          42  ...        34.661654                   35.6     1.090226                2.0           1.676692                      0.8          0.067419                0.065848                0.044702                      0.036364
12103  2022-06-14  New Zealand      0            1       1206.07             1503.09          1161.66                1464.06   101          31  ...       109.125000                  156.2     1.900000                3.0           0.925000                      0.0          0.021997                0.019261                0.010179                      0.000000

5 rows × 31 columns

# Split: keep only the 16 historical-average columns and prefix them with home_ / away_
home_team_data = home_team_data[home_team_data.columns[-16:]]
away_team_data = away_team_data[away_team_data.columns[-16:]]
home_team_data.columns = ['home_' + str(col) for col in home_team_data.columns]
away_team_data.columns = ['away_' + str(col) for col in away_team_data.columns]
away_team_data.tail()
       away_mean_goals  away_mean_goals_last5  away_rival_mean_goals  away_rival_mean_goals_last5  away_mean_rank  away_mean_rank_last5  away_rival_mean_rank  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5
12099         1.493976                    1.6               0.807229                          1.0       39.265060                  26.0             62.433735                        59.0          1.759036                     1.8                0.891566                           0.6               0.068496                     0.093412                     0.028319                           0.023407
12100         1.511905                    1.8               0.809524                          0.8       39.119048                  26.4             62.154762                        42.0          1.773810                     2.2                0.880952                           0.4               0.068596                     0.107183                     0.027982                           0.015407
12101         2.514286                    1.4               0.809524                          1.2       42.885714                  35.6             81.504762                        65.6          1.980952                     1.0                0.809524                           1.6               0.029632                     0.011321                     0.021702                           0.044029
12102         1.015038                    1.2               1.466165                          0.6       46.751880                  22.4             34.661654                        35.6          1.090226                     2.0                1.676692                           0.8               0.067419                     0.065848                     0.044702                           0.036364
12103         2.200000                    3.6               0.900000                          0.2      101.725000                 111.0            109.125000                       156.2          1.900000                     3.0                0.925000                           0.0               0.021997                     0.019261                     0.010179                           0.000000
# Merge: put the home and away averages side by side and attach them to fifa_data
team_data = pd.concat([home_team_data, away_team_data.reset_index(drop=True)], axis=1, ignore_index=False)
fifa_data = pd.concat([fifa_data, team_data.reset_index(drop=True)], axis=1, ignore_index=False)
fifa_data.columns
Index(['date', 'home_team', 'away_team', 'home_score', 'away_score', 'city',
       'country', 'neutral', 'total_points_home', 'previous_points_home',
       'rank_home', 'rank_change_home', 'total_points_away',
       'previous_points_away', 'rank_away', 'rank_change_away', 'result',
       'home_points', 'away_points', 'target', 'rank_diff', 'rank_change_diff',
       'total_points_diff', 'previous_points_diff', 'home_points2rank',
       'away_points2rank', 'points2rank_diff', 'home_mean_goals',
       'home_mean_goals_last5', 'home_rival_mean_goals',
       'home_rival_mean_goals_last5', 'home_mean_rank', 'home_mean_rank_last5',
       'home_rival_mean_rank', 'home_rival_mean_rank_last5',
       'home_mean_points', 'home_mean_points_last5', 'home_rival_mean_points',
       'home_rival_mean_points_last5', 'home_mean_points2rank',
       'home_mean_points2rank_last5', 'home_rival_mean_points2rank',
       'home_rival_mean_points2rank_last5', 'away_mean_goals',
       'away_mean_goals_last5', 'away_rival_mean_goals',
       'away_rival_mean_goals_last5', 'away_mean_rank', 'away_mean_rank_last5',
       'away_rival_mean_rank', 'away_rival_mean_rank_last5',
       'away_mean_points', 'away_mean_points_last5', 'away_rival_mean_points',
       'away_rival_mean_points_last5', 'away_mean_points2rank',
       'away_mean_points2rank_last5', 'away_rival_mean_points2rank',
       'away_rival_mean_points2rank_last5'],
      dtype='object')
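
`pd.concat(..., axis=1)` aligns rows on the index rather than on position, which is presumably why the away half is passed through `reset_index(drop=True)` before the concatenation above. A minimal sketch of what happens with and without the reset (toy frames and column names are invented):

```python
import pandas as pd

# The home half keeps index 0..1; the away half still carries its original index 2..3.
home = pd.DataFrame({'home_x': [1, 2]}, index=[0, 1])
away = pd.DataFrame({'away_x': [10, 20]}, index=[2, 3])

# Without a reset the indices do not overlap: 4 rows, half of each column NaN.
misaligned = pd.concat([home, away], axis=1)

# After the reset both halves share index 0..1 and line up row by row.
aligned = pd.concat([home, away.reset_index(drop=True)], axis=1)

print(misaligned.shape, aligned.shape)
```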
# Select and reorder the columns used from here on
fifa_data = fifa_data[['date', 'home_team', 'away_team', 'rank_home', 'rank_away', 'home_score', 'away_score', 'result', 'rank_diff', 'rank_change_diff', 'total_points_diff', 'previous_points_diff', 'points2rank_diff', 'home_mean_goals', 'home_mean_goals_last5', 'home_rival_mean_goals', 'home_rival_mean_goals_last5', 'home_mean_rank', 'home_mean_rank_last5', 'home_rival_mean_rank', 'home_rival_mean_rank_last5', 'home_mean_points', 'home_mean_points_last5', 'home_rival_mean_points', 'home_rival_mean_points_last5', 'home_mean_points2rank', 'home_mean_points2rank_last5', 'home_rival_mean_points2rank', 'home_rival_mean_points2rank_last5', 'away_mean_goals', 'away_mean_goals_last5', 'away_rival_mean_goals', 'away_rival_mean_goals_last5', 'away_mean_rank', 'away_mean_rank_last5', 'away_rival_mean_rank', 'away_rival_mean_rank_last5', 'away_mean_points', 'away_mean_points_last5', 'away_rival_mean_points', 'away_rival_mean_points_last5', 'away_mean_points2rank', 'away_mean_points2rank_last5', 'away_rival_mean_points2rank', 'away_rival_mean_points2rank_last5', 'target']]
fifa_data.head()
         date     home_team    away_team  rank_home  rank_away  home_score  away_score  result  rank_diff  rank_change_diff  ...  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5  target
0  1993-01-10        Angola     Zimbabwe        102         54           1           1       2         48                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
1  1993-01-16  South Africa      Nigeria        124         13           0           0       2        111                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
2  1993-01-16      Tanzania       Zambia         80         32           1           3       1         48                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
3  1993-01-17         Benin      Tunisia        127         38           0           5       1         89                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1
4  1993-01-17      Botswana  Ivory Coast        139         27           0           0       2        112                 0  ...                         NaN               NaN                     NaN                     NaN                           NaN                    NaN                          NaN                          NaN                                NaN       1

5 rows × 46 columns

# Check for missing values
fifa_data.isna().sum()
date                                   0
home_team                              0
away_team                              0
rank_home                              0
rank_away                              0
home_score                             0
away_score                             0
result                                 0
rank_diff                              0
rank_change_diff                       0
total_points_diff                      0
previous_points_diff                   0
points2rank_diff                       0
home_mean_goals                      101
home_mean_goals_last5                101
home_rival_mean_goals                101
home_rival_mean_goals_last5          101
home_mean_rank                       101
home_mean_rank_last5                 101
home_rival_mean_rank                 101
home_rival_mean_rank_last5           101
home_mean_points                     101
home_mean_points_last5               101
home_rival_mean_points               101
home_rival_mean_points_last5         101
home_mean_points2rank                101
home_mean_points2rank_last5          101
home_rival_mean_points2rank          101
home_rival_mean_points2rank_last5    101
away_mean_goals                      106
away_mean_goals_last5                106
away_rival_mean_goals                106
away_rival_mean_goals_last5          106
away_mean_rank                       106
away_mean_rank_last5                 106
away_rival_mean_rank                 106
away_rival_mean_rank_last5           106
away_mean_points                     106
away_mean_points_last5               106
away_rival_mean_points               106
away_rival_mean_points_last5         106
away_mean_points2rank                106
away_mean_points2rank_last5          106
away_rival_mean_points2rank          106
away_rival_mean_points2rank_last5    106
target                                 0
dtype: int64
# Handle missing values: drop matches where either team has no earlier match to average over
fifa_data = fifa_data.dropna().reset_index(drop=True)
fifa_data.isna().sum()
date                                 0
home_team                            0
away_team                            0
rank_home                            0
rank_away                            0
home_score                           0
away_score                           0
result                               0
rank_diff                            0
rank_change_diff                     0
total_points_diff                    0
previous_points_diff                 0
points2rank_diff                     0
home_mean_goals                      0
home_mean_goals_last5                0
home_rival_mean_goals                0
home_rival_mean_goals_last5          0
home_mean_rank                       0
home_mean_rank_last5                 0
home_rival_mean_rank                 0
home_rival_mean_rank_last5           0
home_mean_points                     0
home_mean_points_last5               0
home_rival_mean_points               0
home_rival_mean_points_last5         0
home_mean_points2rank                0
home_mean_points2rank_last5          0
home_rival_mean_points2rank          0
home_rival_mean_points2rank_last5    0
away_mean_goals                      0
away_mean_goals_last5                0
away_rival_mean_goals                0
away_rival_mean_goals_last5          0
away_mean_rank                       0
away_mean_rank_last5                 0
away_rival_mean_rank                 0
away_rival_mean_rank_last5           0
away_mean_points                     0
away_mean_points_last5               0
away_rival_mean_points               0
away_rival_mean_points_last5         0
away_mean_points2rank                0
away_mean_points2rank_last5          0
away_rival_mean_points2rank          0
away_rival_mean_points2rank_last5    0
target                               0
dtype: int64
# Inspect the data
fifa_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5894 entries, 0 to 5893
Data columns (total 46 columns):
 #   Column                             Non-Null Count  Dtype         
---  ------                             --------------  -----         
 0   date                               5894 non-null   datetime64[ns]
 1   home_team                          5894 non-null   object        
 2   away_team                          5894 non-null   object        
 3   rank_home                          5894 non-null   int64         
 4   rank_away                          5894 non-null   int64         
 5   home_score                         5894 non-null   int64         
 6   away_score                         5894 non-null   int64         
 7   result                             5894 non-null   int64         
 8   rank_diff                          5894 non-null   int64         
 9   rank_change_diff                   5894 non-null   int64         
 10  total_points_diff                  5894 non-null   float64       
 11  previous_points_diff               5894 non-null   float64       
 12  points2rank_diff                   5894 non-null   float64       
 13  home_mean_goals                    5894 non-null   float64       
 14  home_mean_goals_last5              5894 non-null   float64       
 15  home_rival_mean_goals              5894 non-null   float64       
 16  home_rival_mean_goals_last5        5894 non-null   float64       
 17  home_mean_rank                     5894 non-null   float64       
 18  home_mean_rank_last5               5894 non-null   float64       
 19  home_rival_mean_rank               5894 non-null   float64       
 20  home_rival_mean_rank_last5         5894 non-null   float64       
 21  home_mean_points                   5894 non-null   float64       
 22  home_mean_points_last5             5894 non-null   float64       
 23  home_rival_mean_points             5894 non-null   float64       
 24  home_rival_mean_points_last5       5894 non-null   float64       
 25  home_mean_points2rank              5894 non-null   float64       
 26  home_mean_points2rank_last5        5894 non-null   float64       
 27  home_rival_mean_points2rank        5894 non-null   float64       
 28  home_rival_mean_points2rank_last5  5894 non-null   float64       
 29  away_mean_goals                    5894 non-null   float64       
 30  away_mean_goals_last5              5894 non-null   float64       
 31  away_rival_mean_goals              5894 non-null   float64       
 32  away_rival_mean_goals_last5        5894 non-null   float64       
 33  away_mean_rank                     5894 non-null   float64       
 34  away_mean_rank_last5               5894 non-null   float64       
 35  away_rival_mean_rank               5894 non-null   float64       
 36  away_rival_mean_rank_last5         5894 non-null   float64       
 37  away_mean_points                   5894 non-null   float64       
 38  away_mean_points_last5             5894 non-null   float64       
 39  away_rival_mean_points             5894 non-null   float64       
 40  away_rival_mean_points_last5       5894 non-null   float64       
 41  away_mean_points2rank              5894 non-null   float64       
 42  away_mean_points2rank_last5        5894 non-null   float64       
 43  away_rival_mean_points2rank        5894 non-null   float64       
 44  away_rival_mean_points2rank_last5  5894 non-null   float64       
 45  target                             5894 non-null   int64         
dtypes: datetime64[ns](1), float64(35), int64(8), object(2)
memory usage: 2.1+ MB
# Split the columns into three groups (each group keeps the target)
data1 = fifa_data[list(fifa_data.columns[8:13].values) + ['target']]
data2 = fifa_data[list(fifa_data.columns[13:29].values) + ['target']]
data3 = fifa_data[fifa_data.columns[29:]]
# Preview the data
data1.tail()
      rank_diff  rank_change_diff  total_points_diff  previous_points_diff  points2rank_diff  target
5889         12                -1             -62.42                -63.26         -0.076923       1
5890         -9                -2              53.00                 42.93          0.111111       0
5891         26                -6            -105.30               -133.76         -0.044118       1
5892         20                 5            -100.03                -76.59          0.021645       1
5893        -70                -1             297.02                302.40          0.029703       0
# Preview the data
data2.tail()
      home_mean_goals  home_mean_goals_last5  home_rival_mean_goals  home_rival_mean_goals_last5  home_mean_rank  home_mean_rank_last5  home_rival_mean_rank  home_rival_mean_rank_last5  home_mean_points  home_mean_points_last5  home_rival_mean_points  home_rival_mean_points_last5  home_mean_points2rank  home_mean_points2rank_last5  home_rival_mean_points2rank  home_rival_mean_points2rank_last5  target
5889         1.310811                    1.8               0.986486                          0.4       45.229730                  44.6             61.270270                        81.6          1.675676                     3.0                1.067568                           0.0               0.053158                     0.102165                     0.026165                           0.000000       1
5890         1.338028                    2.2               1.380282                          1.0       56.478873                  19.2             60.746479                        53.6          1.267606                     2.2                1.478873                           0.4               0.036644                     0.238173                     0.033925                           0.021053       0
5891         1.650000                    0.8               1.140000                          0.4       77.240000                  69.4             94.390000                        60.4          1.470000                     1.8                1.350000                           1.2               0.015780                     0.034188                     0.017614                           0.017391       1
5892         2.509434                    1.6               0.811321                          1.2       42.877358                  37.2             81.377358                        64.2          1.990566                     1.4                0.801887                           1.4               0.029768                     0.017478                     0.021497                           0.038147       1
5893         1.557252                    1.2               1.007634                          0.2       43.229008                  44.8             49.969466                        37.4          1.725191                     2.6                1.038168                           0.2               0.052968                     0.097719                     0.030313                           0.004082       0
# Preview the data
data3.tail()
      away_mean_goals  away_mean_goals_last5  away_rival_mean_goals  away_rival_mean_goals_last5  away_mean_rank  away_mean_rank_last5  away_rival_mean_rank  away_rival_mean_rank_last5  away_mean_points  away_mean_points_last5  away_rival_mean_points  away_rival_mean_points_last5  away_mean_points2rank  away_mean_points2rank_last5  away_rival_mean_points2rank  away_rival_mean_points2rank_last5  target
5889         1.493976                    1.6               0.807229                          1.0       39.265060                  26.0             62.433735                        59.0          1.759036                     1.8                0.891566                           0.6               0.068496                     0.093412                     0.028319                           0.023407       1
5890         1.511905                    1.8               0.809524                          0.8       39.119048                  26.4             62.154762                        42.0          1.773810                     2.2                0.880952                           0.4               0.068596                     0.107183                     0.027982                           0.015407       0
5891         2.514286                    1.4               0.809524                          1.2       42.885714                  35.6             81.504762                        65.6          1.980952                     1.0                0.809524                           1.6               0.029632                     0.011321                     0.021702                           0.044029       1
5892         1.015038                    1.2               1.466165                          0.6       46.751880                  22.4             34.661654                        35.6          1.090226                     2.0                1.676692                           0.8               0.067419                     0.065848                     0.044702                           0.036364       1
5893         2.200000                    3.6               0.900000                          0.2      101.725000                 111.0            109.125000                       156.2          1.900000                     3.0                0.925000                           0.0               0.021997                     0.019261                     0.010179                           0.000000       0
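# Violin-plot data: standardize the feature columns, then reshape to long format.
# Note: data1[:-1] slices rows (it drops the last match), not the target column,
# so the target is dropped explicitly before standardizing.
features1 = data1.drop(columns='target')
standard1 = (features1 - features1.mean()) / features1.std()
standard1['target'] = data1['target']
violin1 = pd.melt(standard1, id_vars='target', var_name='feature', value_name='value')

features2 = data2.drop(columns='target')
standard2 = (features2 - features2.mean()) / features2.std()
standard2['target'] = data2['target']
violin2 = pd.melt(standard2, id_vars='target', var_name='feature', value_name='value')

features3 = data3.drop(columns='target')
standard3 = (features3 - features3.mean()) / features3.std()
standard3['target'] = data3['target']
violin3 = pd.melt(standard3, id_vars='target', var_name='feature', value_name='value')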
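The cell above z-scores each feature and then melts the frame into long format, which is the shape seaborn's violinplot expects. A minimal sketch of both steps on a toy frame (the columns feat_a and feat_b are made up for illustration and are not part of the project data):

```python
import pandas as pd

# Toy frame with two hypothetical features and a binary target
toy = pd.DataFrame({
    'feat_a': [1.0, 2.0, 3.0, 4.0],
    'feat_b': [10.0, 10.0, 30.0, 30.0],
    'target': [0, 1, 0, 1],
})

features = toy.drop(columns='target')
# z-score: subtract the column mean, divide by the sample standard deviation
standard = (features - features.mean()) / features.std()
standard['target'] = toy['target']

# Wide -> long: one row per (sample, feature) pair
long_form = pd.melt(standard, id_vars='target', var_name='feature', value_name='value')
print(long_form.shape)  # (8, 3)
```

After standardization every feature shares the same scale, so the violins can be drawn on one common y axis.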
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin1, split=True, inner='quart')
plt.show()

[Figure: violin plot of the standardized data1 features (main_files/main_43_0.png)]

# Draw the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard1.corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of the data1 features (main_files/main_44_0.png)]

# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin2, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()

[Figure: violin plot of the standardized data2 features (main_files/main_45_0.png)]

# Draw the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard2.corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of the data2 features (main_files/main_46_0.png)]

# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin3, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()

[Figure: violin plot of the standardized data3 features (main_files/main_47_0.png)]

# Draw the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard3.corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of the data3 features (main_files/main_48_0.png)]

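Feature engineering

Each feature below is the home-team value minus the away-team value.

  • mean_goals_diff - difference in mean goals scored
  • mean_goals_last5_diff - difference in mean goals scored over the last five matches
  • rival_mean_goals_diff - difference in opponents' mean goals
  • rival_mean_goals_last5_diff - difference in opponents' mean goals over the last five matches
  • mean_rank_diff - difference in mean rank
  • mean_rank_last5_diff - difference in mean rank over the last five matches
  • rival_mean_rank_diff - difference in opponents' mean rank
  • rival_mean_rank_last5_diff - difference in opponents' mean rank over the last five matches
  • mean_points_diff - difference in mean points
  • mean_points_last5_diff - difference in mean points over the last five matches
  • rival_mean_points_diff - difference in opponents' mean points
  • rival_mean_points_last5_diff - difference in opponents' mean points over the last five matches
  • mean_points2rank_diff - difference in mean points2rank
  • mean_points2rank_last5_diff - difference in mean points2rank over the last five matches
  • rival_mean_points2rank_diff - difference in opponents' mean points2rank
  • rival_mean_points2rank_last5_diff - difference in opponents' mean points2rank over the last five matches
# Feature engineering: home-minus-away difference features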
data = fifa_data.copy()
data.loc[:, 'mean_goals_diff'] = data['home_mean_goals'] - data['away_mean_goals']
data.loc[:, 'mean_goals_last5_diff'] = data['home_mean_goals_last5'] - data['away_mean_goals_last5']
data.loc[:, 'rival_mean_goals_diff'] = data['home_rival_mean_goals'] - data['away_rival_mean_goals']
data.loc[:, 'rival_mean_goals_last5_diff'] = data['home_rival_mean_goals_last5'] - data['away_rival_mean_goals_last5']
data.loc[:, 'mean_rank_diff'] = data['home_mean_rank'] - data['away_mean_rank']
data.loc[:, 'mean_rank_last5_diff'] = data['home_mean_rank_last5'] - data['away_mean_rank_last5']
data.loc[:, 'rival_mean_rank_diff'] = data['home_rival_mean_rank'] - data['away_rival_mean_rank']
data.loc[:, 'rival_mean_rank_last5_diff'] = data['home_rival_mean_rank_last5'] - data['away_rival_mean_rank_last5']
data.loc[:, 'mean_points_diff'] = data['home_mean_points'] - data['away_mean_points']
data.loc[:, 'mean_points_last5_diff'] = data['home_mean_points_last5'] - data['away_mean_points_last5']
data.loc[:, 'rival_mean_points_diff'] = data['home_rival_mean_points'] - data['away_rival_mean_points']
data.loc[:, 'rival_mean_points_last5_diff'] = data['home_rival_mean_points_last5'] - data['away_rival_mean_points_last5']
data.loc[:, 'mean_points2rank_diff'] = data['home_mean_points2rank'] - data['away_mean_points2rank']
data.loc[:, 'mean_points2rank_last5_diff'] = data['home_mean_points2rank_last5'] - data['away_mean_points2rank_last5']
data.loc[:, 'rival_mean_points2rank_diff'] = data['home_rival_mean_points2rank'] - data['away_rival_mean_points2rank']
data.loc[:, 'rival_mean_points2rank_last5_diff'] = data['home_rival_mean_points2rank_last5'] - data['away_rival_mean_points2rank_last5']
data_diff1 = data.iloc[:, -16:]
standard_diff1 = (data_diff1 - data_diff1.mean()) / data_diff1.std()
standard_diff1['target'] = data['target']
violin_diff1 = pd.melt(standard_diff1, id_vars='target', var_name='feature', value_name='value')
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin_diff1, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()

[Figure: violin plot of the standardized difference features (main_files/main_50_0.png)]

# Draw the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard_diff1.corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of the difference features (main_files/main_51_0.png)]
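The sixteen difference features above all follow one pattern: a home_ column minus the matching away_ column. As a hedged alternative to writing each assignment by hand, the same columns could be generated in a loop. The sketch below uses a made-up two-row frame with only two of the statistics; the stats list would need to be extended with the remaining names.

```python
import pandas as pd

# Hypothetical two-match frame holding only two of the sixteen stat pairs
toy = pd.DataFrame({
    'home_mean_goals': [1.5, 2.0], 'away_mean_goals': [1.0, 2.5],
    'home_mean_rank': [10.0, 40.0], 'away_mean_rank': [25.0, 15.0],
})

stats = ['mean_goals', 'mean_rank']  # extend with the remaining stat names
for stat in stats:
    # home minus away, the same sign convention as the hand-written version
    toy[f'{stat}_diff'] = toy[f'home_{stat}'] - toy[f'away_{stat}']

print(toy[['mean_goals_diff', 'mean_rank_diff']])
```

The loop keeps the column names and sign convention identical, so downstream code that selects columns by name would not need to change.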

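Feature engineering

  • mean_goals2mean_rank_diff - home mean goals / home mean rank - away mean goals / away mean rank
  • rival_mean_goals2mean_rank_diff - home opponents' mean goals / home mean rank - away opponents' mean goals / away mean rank
  • mean_goals2mean_rank_last5_diff - home mean goals over the last five matches / home mean rank - away mean goals over the last five matches / away mean rank
  • rival_mean_goals2mean_rank_last5_diff - home opponents' mean goals over the last five matches / home mean rank - away opponents' mean goals over the last five matches / away mean rank
  • mean_points2mean_rank_diff - home mean points / home mean rank - away mean points / away mean rank
  • rival_mean_points2mean_rank_diff - home opponents' mean points / home mean rank - away opponents' mean points / away mean rank
  • mean_points2mean_rank_last5_diff - home mean points over the last five matches / home mean rank - away mean points over the last five matches / away mean rank
  • rival_mean_points2mean_rank_last5_diff - home opponents' mean points over the last five matches / home mean rank - away opponents' mean points over the last five matches / away mean rank
# Feature engineering: rank-normalized difference features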
data.loc[:, 'mean_goals2mean_rank_diff'] = (data['home_mean_goals'] / data['home_mean_rank']) - (data['away_mean_goals'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_goals2mean_rank_diff'] = (data['home_rival_mean_goals'] / data['home_mean_rank']) - (data['away_rival_mean_goals'] / data['away_mean_rank'])
data.loc[:, 'mean_goals2mean_rank_last5_diff'] = (data['home_mean_goals_last5'] / data['home_mean_rank']) - (data['away_mean_goals_last5'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_goals2mean_rank_last5_diff'] = (data['home_rival_mean_goals_last5'] / data['home_mean_rank']) - (data['away_rival_mean_goals_last5'] / data['away_mean_rank'])
data.loc[:, 'mean_points2mean_rank_diff'] = (data['home_mean_points'] / data['home_mean_rank']) - (data['away_mean_points'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_points2mean_rank_diff'] = (data['home_rival_mean_points'] / data['home_mean_rank']) - (data['away_rival_mean_points'] / data['away_mean_rank'])
data.loc[:, 'mean_points2mean_rank_last5_diff'] = (data['home_mean_points_last5'] / data['home_mean_rank']) - (data['away_mean_points_last5'] / data['away_mean_rank'])
data.loc[:, 'rival_mean_points2mean_rank_last5_diff'] = (data['home_rival_mean_points_last5'] / data['home_mean_rank']) - (data['away_rival_mean_points_last5'] / data['away_mean_rank'])
data_diff2 = data.iloc[:, -8:]
standard_diff2 = (data_diff2 - data_diff2.mean()) / data_diff2.std()
standard_diff2['target'] = data['target']
violin_diff2 = pd.melt(standard_diff2, id_vars='target', var_name='feature', value_name='value')
# Draw the violin plot
plt.figure(figsize=(15, 10))
sns.violinplot(x='feature', y='value', hue='target', data=violin_diff2, split=True, inner='quart')
plt.xticks(rotation=90)
plt.show()

[Figure: violin plot of the standardized rank-normalized features (main_files/main_53_0.png)]

# Draw the box plot
plt.figure(figsize=(15, 10))
sns.boxplot(x='feature', y='value', hue='target', data=violin_diff2)
plt.xticks(rotation=90)
plt.show()

[Figure: box plot of the standardized rank-normalized features (main_files/main_54_0.png)]

# Draw the correlation heatmap
plt.figure(figsize=(16, 16))
sns.heatmap(standard_diff2.corr(), annot=True, linewidths=0.2, square=True)
plt.show()

[Figure: correlation heatmap of the rank-normalized features (main_files/main_55_0.png)]

Keep the features with a correlation greater than 0.3

  • rank_diff
  • total_points_diff
  • previous_points_diff
  • away_mean_rank
  • away_mean_rank_last5
  • away_mean_points
  • away_rival_mean_points
  • mean_goals_diff
  • mean_goals_last5_diff
  • rival_mean_goals_diff
  • rival_mean_goals_last5_diff
  • mean_rank_diff
  • mean_rank_last5_diff
  • mean_points_diff
  • mean_points_last5_diff
  • rival_mean_points_diff
  • rival_mean_points_last5_diff
  • mean_points2rank_diff
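The list above was read off the correlation heatmaps. A minimal sketch of how such a filter could be automated, assuming the 0.3 threshold is applied to the absolute Pearson correlation with target (the source does not state this explicitly); the data and column names here are synthetic:

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
n = 500
target = rng.integers(0, 2, size=n)
toy = pd.DataFrame({
    'informative': target + rng.normal(0, 0.5, size=n),  # tracks the target
    'noise': rng.normal(0, 1, size=n),                    # unrelated to it
    'target': target,
})

# Absolute Pearson correlation of every feature with the target
corr_with_target = toy.corr()['target'].drop('target').abs()
selected = corr_with_target[corr_with_target > 0.3].index.tolist()
print(selected)
```

Pearson correlation only captures linear association, so a feature dropped by this filter can still be useful to a tree-based model.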
# Scatter plot with regression fit (jointplot creates its own figure, so no plt.figure call is needed)
sns.jointplot(x='total_points_diff', y='previous_points_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of total_points_diff vs previous_points_diff (main_files/main_57_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='away_mean_rank', y='away_mean_rank_last5', data=data, kind='reg')
plt.show()

[Figure: joint plot of away_mean_rank vs away_mean_rank_last5 (main_files/main_58_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='mean_goals_diff', y='mean_goals_last5_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of mean_goals_diff vs mean_goals_last5_diff (main_files/main_59_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='rival_mean_goals_diff', y='rival_mean_goals_last5_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of rival_mean_goals_diff vs rival_mean_goals_last5_diff (main_files/main_60_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='mean_rank_diff', y='mean_rank_last5_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of mean_rank_diff vs mean_rank_last5_diff (main_files/main_61_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='mean_points_diff', y='mean_points_last5_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of mean_points_diff vs mean_points_last5_diff (main_files/main_62_1.png)]

# Scatter plot with regression fit (jointplot creates its own figure)
sns.jointplot(x='rival_mean_points_diff', y='rival_mean_points_last5_diff', data=data, kind='reg')
plt.show()

[Figure: joint plot of rival_mean_points_diff vs rival_mean_points_last5_diff (main_files/main_63_1.png)]

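Drop features with similar distributions

previous_points_diff and rival_mean_goals_last5_diff are removed; the following features remain: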

  • rank_diff
  • total_points_diff
  • away_mean_rank
  • away_mean_rank_last5
  • away_mean_points
  • away_rival_mean_points
  • mean_goals_diff
  • mean_goals_last5_diff
  • rival_mean_goals_diff
  • mean_rank_diff
  • mean_rank_last5_diff
  • mean_points_diff
  • mean_points_last5_diff
  • rival_mean_points_diff
  • rival_mean_points_last5_diff
  • mean_points2rank_diff
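In the source, near-duplicate features were identified by inspecting the joint plots. As a hedged, automated proxy (a swapped-in technique, not the author's procedure), one could drop one member of every feature pair whose absolute correlation exceeds a threshold. The 0.9 threshold and the synthetic data below are assumptions; the column names are borrowed only as labels.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
base = rng.normal(size=300)
toy = pd.DataFrame({
    'total_points_diff': base,
    'previous_points_diff': base + rng.normal(0, 0.05, size=300),  # near copy
    'mean_goals_diff': rng.normal(size=300),                        # independent
})

corr = toy.corr().abs()
# Keep only the upper triangle so each pair is inspected once
upper = corr.where(np.triu(np.ones(corr.shape, dtype=bool), k=1))
to_drop = [col for col in upper.columns if (upper[col] > 0.9).any()]
reduced = toy.drop(columns=to_drop)
print(to_drop)
```

Of each highly correlated pair, this keeps the column that appears first, so the column order decides which member survives.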
# Build the training data
fifa_data = data[['home_team', 'away_team', 'target', 'rank_diff', 'total_points_diff', 'away_mean_rank', 'away_mean_rank_last5', 'away_mean_points', 'away_rival_mean_points', 'mean_goals_diff', 'mean_goals_last5_diff', 'rival_mean_goals_diff', 'mean_rank_diff', 'mean_rank_last5_diff', 'mean_points_diff', 'mean_points_last5_diff', 'rival_mean_points_diff', 'rival_mean_points_last5_diff', 'mean_points2rank_diff']]
fifa_data.head()
   home_team  away_team  target  rank_diff  total_points_diff  away_mean_rank  away_mean_rank_last5  away_mean_points  away_rival_mean_points  mean_goals_diff  mean_goals_last5_diff  rival_mean_goals_diff  mean_rank_diff  mean_rank_last5_diff  mean_points_diff  mean_points_last5_diff  rival_mean_points_diff  rival_mean_points_last5_diff  mean_points2rank_diff
0      Egypt       Togo       0        -80               35.0           101.0                 101.0               0.0                     3.0             -1.0                   -1.0                   -2.0           -80.0                 -80.0               1.0                     1.0                    -2.0                          -2.0               0.009804
1    Morocco      Benin       0        -86               28.0           127.0                 127.0               0.0                     3.0              1.0                    1.0                   -5.0           -86.0                 -86.0               3.0                     3.0                    -3.0                          -3.0               0.035294
2    Tunisia   Ethiopia       0        -47               21.0            85.0                  85.0               0.0                     3.0              5.0                    5.0                   -1.0           -47.0                 -47.0               3.0                     3.0                    -3.0                          -3.0               0.023622
3   Zimbabwe     Angola       0        -48               17.0           102.0                 102.0               1.0                     1.0              1.0                    1.0                    0.5           -48.0                 -48.0               1.0                     1.0                    -0.5                          -0.5              -0.013315
4    Algeria      Ghana       0         -9                5.0            39.0                  39.0               3.0                     0.0             -1.0                   -1.0                    0.0            -9.0                  -9.0              -2.0                    -2.0                     1.0                           1.0              -0.020000
# Inspect the data summary
fifa_data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5894 entries, 0 to 5893
Data columns (total 19 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   home_team                     5894 non-null   object 
 1   away_team                     5894 non-null   object 
 2   target                        5894 non-null   int64  
 3   rank_diff                     5894 non-null   int64  
 4   total_points_diff             5894 non-null   float64
 5   away_mean_rank                5894 non-null   float64
 6   away_mean_rank_last5          5894 non-null   float64
 7   away_mean_points              5894 non-null   float64
 8   away_rival_mean_points        5894 non-null   float64
 9   mean_goals_diff               5894 non-null   float64
 10  mean_goals_last5_diff         5894 non-null   float64
 11  rival_mean_goals_diff         5894 non-null   float64
 12  mean_rank_diff                5894 non-null   float64
 13  mean_rank_last5_diff          5894 non-null   float64
 14  mean_points_diff              5894 non-null   float64
 15  mean_points_last5_diff        5894 non-null   float64
 16  rival_mean_points_diff        5894 non-null   float64
 17  rival_mean_points_last5_diff  5894 non-null   float64
 18  mean_points2rank_diff         5894 non-null   float64
dtypes: float64(15), int64(2), object(2)
memory usage: 875.0+ KB

Model training

# Split the data into training and test sets
X_train, X_test, y_train, y_test = train_test_split(fifa_data.iloc[:, 3:], fifa_data['target'], test_size=0.2, shuffle=True, random_state=2022)

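Grid search is an exhaustive search method. It assigns a set of candidate values to each hyperparameter, forms the Cartesian product of those sets, trains and evaluates the model on every combination, and keeps the combination that scores best. Because every combination is tried, grid search is guaranteed to find the best-scoring parameters within the specified grid; for the same reason it becomes very slow on large datasets or with many hyperparameters. Only a single combination is shown here. Set your own candidate values as needed, for example 'max_depth': [3, 5, 7].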
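To make the Cartesian-product idea concrete, here is a minimal sketch with more than one candidate per hyperparameter. The dataset is synthetic (make_classification stands in for X_train and y_train), and the candidate values are illustrative rather than tuned.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for X_train / y_train
X, y = make_classification(n_samples=300, n_features=8, random_state=2022)

param_grid = {
    'max_depth': [3, 5, 7],
    'n_estimators': [20, 50],
}
# 3 x 2 = 6 combinations, each evaluated with 3-fold cross-validation
search = GridSearchCV(RandomForestClassifier(random_state=2022), param_grid, cv=3)
search.fit(X, y)

print(len(search.cv_results_['params']))  # 6 candidate combinations
print(search.best_params_)
```

The number of fits is the product of all candidate-list lengths times the number of folds, which is why the cost grows so quickly as hyperparameters are added.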

# Grid search
rf_params = {
    'max_depth': [10], 
    'max_features': ['sqrt'], 
    'min_samples_leaf': [10], 
    'min_samples_split': [10], 
    'n_estimators': [100]
}
rf_search = GridSearchCV(RandomForestClassifier(), rf_params, cv=3, n_jobs=-1)
rf_search.fit(X_train, y_train)
rf_search.best_params_
{'max_depth': 10,
 'max_features': 'sqrt',
 'min_samples_leaf': 10,
 'min_samples_split': 10,
 'n_estimators': 100}

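Random forest is an ensemble algorithm of the Bagging type (a parallel method in which the individual learners have no strong dependence on one another and can be built at the same time). It combines many weak classifiers and obtains the final result by voting or averaging, which gives the overall model high accuracy and good generalization. Its strong results come from the two words in its name: "random" makes it resistant to overfitting, and "forest" makes it more accurate.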

# Train the model
rf = RandomForestClassifier(max_depth=10, max_features='sqrt', min_samples_leaf=10, min_samples_split=10, n_estimators=100, random_state=2022)
rf.fit(X_train, y_train)
rf_pred = rf.predict(X_test)
rf_acc = accuracy_score(y_test, rf_pred.astype('int'))
joblib.dump(rf, 'rf.pkl')
print('RandomForest Acc is: ', rf_acc)
RandomForest Acc is:  0.732824427480916
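The "random" and "forest" parts described above can be sketched by hand with plain decision trees: each tree is trained on a bootstrap sample with a random feature subset per split, and the ensemble predicts by majority vote. This is an illustration on synthetic data, not the project's model.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = make_classification(n_samples=400, n_features=10, random_state=0)
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.25, random_state=0)

rng = np.random.default_rng(0)
trees = []
for i in range(25):
    # "Random": a bootstrap sample of the rows, plus a random feature subset per split
    idx = rng.integers(0, len(X_tr), size=len(X_tr))
    tree = DecisionTreeClassifier(max_features='sqrt', random_state=i)
    trees.append(tree.fit(X_tr[idx], y_tr[idx]))

# "Forest": majority vote over the individual trees (25 is odd, so no ties)
votes = np.stack([t.predict(X_te) for t in trees])  # shape (25, n_test)
forest_pred = (votes.mean(axis=0) > 0.5).astype(int)
print('ensemble accuracy:', (forest_pred == y_te).mean())
```

scikit-learn's RandomForestClassifier averages the trees' predicted class probabilities rather than counting hard votes, so its output can differ slightly from this sketch.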
# Grid search
gbdt_params = {
    'learning_rate': [0.01], 
    'max_depth': [5], 
    'max_features': ['sqrt'], 
    'min_samples_leaf': [10], 
    'min_samples_split': [10], 
    'n_estimators': [500]
} 
gbdt_search = GridSearchCV(GradientBoostingClassifier(), gbdt_params, cv=3, n_jobs=-1)
gbdt_search.fit(X_train, y_train)
gbdt_search.best_params_
{'learning_rate': 0.01,
 'max_depth': 5,
 'max_features': 'sqrt',
 'min_samples_leaf': 10,
 'min_samples_split': 10,
 'n_estimators': 500}

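Gradient boosting decision tree (GBDT) is an ensemble algorithm of the Boosting type (a sequential method in which the individual learners depend strongly on one another and must be built one after another). Training uses a forward stagewise algorithm that learns greedily: at iteration t it fits a CART tree to the residual between the true training targets and the prediction of the previous t-1 trees (for a general loss, the negative gradient of that loss).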

# Train the model
gbdt = GradientBoostingClassifier(learning_rate=0.01, max_depth=5, max_features='sqrt', min_samples_leaf=10, min_samples_split=10, n_estimators=500, random_state=2022)
gbdt.fit(X_train, y_train)
gbdt_pred = gbdt.predict(X_test)
gbdt_acc = accuracy_score(y_test, gbdt_pred.astype('int'))
joblib.dump(gbdt, 'gbdt.pkl')
print('GradientBoosting Acc is: ', gbdt_acc)
GradientBoosting Acc is:  0.7430025445292621
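The residual-fitting loop described above is easiest to see in a regression setting with squared loss, where the negative gradient is exactly the residual. The sketch below is that simplified case on synthetic data; GradientBoostingClassifier applies the same idea to a classification loss, so this is an illustration of the principle rather than of its internals.

```python
import numpy as np
from sklearn.tree import DecisionTreeRegressor

rng = np.random.default_rng(0)
X = rng.uniform(-3, 3, size=(200, 1))
y = np.sin(X[:, 0]) + rng.normal(0, 0.1, size=200)

learning_rate = 0.1
pred = np.full_like(y, y.mean())      # F_0: constant initial model
mse_start = np.mean((y - pred) ** 2)

for t in range(100):
    residual = y - pred               # what the trees built so far still get wrong
    tree = DecisionTreeRegressor(max_depth=2, random_state=t).fit(X, residual)
    pred += learning_rate * tree.predict(X)   # F_t = F_{t-1} + lr * h_t

mse_end = np.mean((y - pred) ** 2)
print(mse_start, mse_end)
```

A small learning rate with many trees, as in the learning_rate=0.01 and n_estimators=500 setting used above, trades training time for a smoother fit.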
# ROC curve and confusion matrix
def analyze(model):
    plt.figure(figsize=(15, 10))
    plt.plot([0, 1], [0, 1], 'k--')
    fpr_train, tpr_train, _ = roc_curve(y_train, model.predict_proba(X_train)[:, 1])
    plt.plot(fpr_train, tpr_train, label='train')
    fpr_test, tpr_test, _ = roc_curve(y_test, model.predict_proba(X_test)[:, 1])
    plt.plot(fpr_test, tpr_test, label='test')
    auc_train = roc_auc_score(y_train, model.predict_proba(X_train)[:, 1])
    auc_test = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1])
    plt.legend()
    plt.title('AUC score is %.2f on test and %.2f on train' % (auc_test, auc_train))
    plt.show()
    
    plt.figure(figsize=(15, 10))
    matrix = confusion_matrix(y_test, model.predict(X_test))
    sns.heatmap(matrix, annot=True, linewidths=0.2, fmt='d')
    plt.title('confusion_matrix on test')
    plt.show()
# Plot the ROC curve and confusion matrix
analyze(rf)

[Figure: random forest ROC curves on train and test (main_files/main_77_0.png)]

[Figure: random forest confusion matrix on test (main_files/main_77_1.png)]

# Plot the ROC curve and confusion matrix
analyze(gbdt)

[Figure: GBDT ROC curves on train and test (main_files/main_78_0.png)]

[Figure: GBDT confusion matrix on test (main_files/main_78_1.png)]

2022 World Cup

# Build features from historical data
def get_data(team):
    pasts = data_copy[(data_copy['team'] == team)].sort_values(by=['date'], ascending=False)
    last5 = pasts.head(5)
    rank = pasts['rank'].values[0]
    total_points = pasts['total_points'].values[0]
    mean_rank = pasts['rank'].mean()
    mean_rank_last5 = last5['rank'].mean()
    mean_goals = pasts['score'].mean()
    mean_goals_last5 = last5['score'].mean()
    mean_points = pasts['points'].mean()
    mean_points_last5 = last5['points'].mean()
    mean_points2rank = pasts['points2rank'].mean()
    rival_mean_goals = pasts['rival_score'].mean()
    rival_mean_points = pasts['rival_points'].mean()
    rival_mean_points_last5 = last5['rival_points'].mean()
    return [rank, total_points, mean_rank, mean_rank_last5, mean_goals, mean_goals_last5, mean_points, mean_points_last5, mean_points2rank, rival_mean_goals, rival_mean_points, rival_mean_points_last5]

def get_feature(team1, team2):
    rank_diff = team1[0] - team2[0]
    total_points_diff = team1[1] - team2[1]
    away_mean_rank = team2[2]
    away_mean_rank_last5 = team2[3]
    away_mean_points = team2[6]
    away_rival_mean_points = team2[10]
    mean_goals_diff = team1[4] - team2[4]
    mean_goals_last5_diff = team1[5] - team2[5]
    rival_mean_goals_diff = team1[9] - team2[9]
    mean_rank_diff = team1[2] - team2[2]
    mean_rank_last5_diff = team1[3] - team2[3]
    mean_points_diff = team1[6] - team2[6]
    mean_points_last5_diff = team1[7] - team2[7]
    rival_mean_points_diff = team1[10] - team2[10]
    rival_mean_points_last5_diff = team1[11] - team2[11]
    mean_points2rank_diff = team1[8] - team2[8]
    return [rank_diff, total_points_diff, away_mean_rank, away_mean_rank_last5, away_mean_points, away_rival_mean_points, mean_goals_diff, mean_goals_last5_diff, rival_mean_goals_diff, mean_rank_diff, mean_rank_last5_diff, mean_points_diff, mean_points_last5_diff, rival_mean_points_diff, rival_mean_points_last5_diff, mean_points2rank_diff]
# Load the data
fifa_2022 = pd.read_csv('/home/aistudio/work/fifa_2022.csv', parse_dates=['date'])
fifa_2022.head()
        date      home_team     away_team
0 2022-11-20          Qatar       Ecuador
1 2022-11-21        Senegal   Netherlands
2 2022-11-21        England          Iran
3 2022-11-21  United States         Wales
4 2022-11-22      Argentina  Saudi Arabia
# Predict the winner of a match
def predict(teams, model):
    home = teams[0]
    away = teams[1]
    team1 = get_data(home)
    team2 = get_data(away)
    # Score both orderings and average them, so the result does not depend on which team is listed first
    feature1 = get_feature(team1, team2)
    feature2 = get_feature(team2, team1)
    proba1 = model.predict_proba([feature1])
    proba2 = model.predict_proba([feature2])
    pred1 = (proba1[0][0] + proba2[0][1]) / 2
    pred2 = (proba2[0][0] + proba1[0][1]) / 2
    if pred1 < pred2:
        print('%s VS %s: %s wins, probability: %.2f' % (home, away, away, pred2))
    else:
        print('%s VS %s: %s wins, probability: %.2f' % (home, away, home, pred1))
# 2022 World Cup: the last 16 fixtures are the knockout matches
game8 = fifa_2022.iloc[-16:-8, 1:]
game4 = fifa_2022.iloc[-8:-4, 1:]
game2 = fifa_2022.iloc[-4:-2, 1:]
game1 = fifa_2022.iloc[-2:, 1:]
# Collect [home_team, away_team] pairs for each round
team8 = game8[['home_team', 'away_team']].values.tolist()
team4 = game4[['home_team', 'away_team']].values.tolist()
team2 = game2[['home_team', 'away_team']].values.tolist()
team1 = game1[['home_team', 'away_team']].values.tolist()
# Round of 16
for teams in team8:
    predict(teams, gbdt)
Netherlands VS United States: Netherlands wins, probability: 0.66
Argentina VS Australia: Argentina wins, probability: 0.83
France VS Poland: France wins, probability: 0.71
England VS Senegal: England wins, probability: 0.65
Japan VS Croatia: Croatia wins, probability: 0.65
Brazil VS South Korea: Brazil wins, probability: 0.82
Morocco VS Spain: Spain wins, probability: 0.81
Portugal VS Switzerland: Portugal wins, probability: 0.52
# Quarter-finals
for teams in team4:
    predict(teams, gbdt)
Croatia VS Brazil: Brazil wins, probability: 0.70
Netherlands VS Argentina: Argentina wins, probability: 0.60
Morocco VS Portugal: Portugal wins, probability: 0.76
England VS France: France wins, probability: 0.55
# Semi-finals
for teams in team2:
    predict(teams, gbdt)
Argentina VS Croatia: Argentina wins, probability: 0.64
France VS Morocco: France wins, probability: 0.73
# Third-place play-off and final
for teams in team1:
    predict(teams, gbdt)
Croatia VS Morocco: Croatia wins, probability: 0.67
Argentina VS France: Argentina wins, probability: 0.55
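The predict function above scores each fixture twice, once per team ordering, and averages the two results. A small worked example of that arithmetic with made-up probabilities (the helper name symmetric_win_probs is hypothetical and not part of the project):

```python
def symmetric_win_probs(proba_home_away, proba_away_home):
    """Average the two orderings, following the predict() logic above.

    Each argument is a [P(class 0), P(class 1)] pair; as in predict(),
    class 0 is read as "the team listed first wins".
    """
    p_home = (proba_home_away[0] + proba_away_home[1]) / 2
    p_away = (proba_away_home[0] + proba_home_away[1]) / 2
    return p_home, p_away

# Made-up probabilities for illustration only
p_home, p_away = symmetric_win_probs([0.70, 0.30], [0.40, 0.60])
print(round(p_home, 2), round(p_away, 2))  # 0.65 0.35
```

Because each input pair sums to 1, the two averaged values also sum to 1, and swapping the team order simply swaps the two outputs.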

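Summary

In the final of the 2022 FIFA World Cup in Qatar, played on December 18, 2022 local time (ending in the early hours of December 19 Beijing time), Argentina beat France in a penalty shoot-out to win the title.

This project is for learning purposes and is meant as hands-on practice with feature engineering. Possible improvements: more data, finer-grained data, further feature engineering, and model construction.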

Acknowledgments

Predicting FIFA 2022 World Cup with ML

This article is a repost of the original project.
