# 练习网址 (practice URL): https://www.kaggle.com/chirag02/fifa19-player-wage-prediction/notebook
# 代码练习 (code practice):
# -*- coding: UTF-8 -*-
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sb
import sys
import io
# Re-wrap stdout as UTF-8 so non-ASCII text can be printed on consoles whose
# default encoding is not UTF-8.
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf8')
df = pd.read_csv("data.csv")
# Uncomment to show every column/row and long cell contents when printing.
# pd.set_option("display.max_columns",None)
# pd.set_option("display.max_rows",None)
# pd.set_option('max_colwidth',1000)
# .iloc selects by integer position (not by label). This prints the first
# 5 rows (head() defaults to 5) and the first 4 columns (positions 0-3).
# print(df.head().iloc[:,:4])
# Drop the "Unnamed: 0" column (presumably an index saved into the CSV);
# inplace=True modifies df itself instead of returning a new DataFrame.
df.drop(columns=["Unnamed: 0"],inplace=True)
# Show a random sample of rows.
# print(df.sample(n=5))
# print(df.describe())
# Integer-spaced histogram bin edges from the minimum to the maximum Overall rating.
bins = np.arange(df["Overall"].min(),df["Overall"].max()+1,1)
# print(bins)
# print(df["Overall"])
# print(df.groupby(["Overall"]).size()[70])
# Set the figure size.
# plt.figure(figsize=[8,5])
# # Histogram of the Overall rating distribution
# plt.hist(df["Overall"],bins=bins)
# plt.title("Overall Rating Distribution")
# plt.xlabel("Mean Overall Rating")
# plt.ylabel("Count")
# plt.show()
# plt.figure(figsize=[16,5])
# plt.suptitle("Overall Rating Vs Age",fontsize=16)
# plt.subplot(1,2,1)
# bin_x = np.arange(df["Age"].min(),df["Age"].max()+1,1)
# bin_y = np.arange(df["Overall"].min(),df["Overall"].max()+2,2)
# plt.hist2d(x = df["Age"],y=df["Overall"],cmap="YlGnBu",bins=[bin_x,bin_y])
# plt.colorbar()
# plt.xlabel("Age(years)")
# plt.ylabel("Overall Rating")
# plt.subplot(1,2,2)
# plt.scatter(x=df["Age"],y=df["Overall"],alpha=0.25,marker=".")
# plt.xlabel("Age(years)")
# plt.ylabel("Overall")
# plt.show()
# sb.jointplot(x=df.Overall,y=df.Potential,kind="kde")
# plt.show()
# plt.figure(figsize=[8,5])
# plt.scatter(x=df.Overall,y=df.Potential,c=df.Age,alpha=0.25,cmap="rainbow")
# plt.colorbar().set_label("Age")
# plt.xlabel("Overall Rating")
# plt.ylabel("Potential")
# plt.show()
# Keep only the columns used for the wage model. .copy() makes df_opa an
# independent DataFrame, so columns assigned to it afterwards do not trigger
# pandas' SettingWithCopyWarning (assignment into a slice of df).
df_opa = df[["ID","Name","Age","Overall","Potential","Value","Wage"]].copy()
# print(df_opa.head())
def currencystrtoint(amount):
    """Convert currency strings such as "€110.5M" or "€565K" to numbers.

    Each entry is expected to start with a single currency symbol, followed by
    a number and an optional magnitude suffix: "M" (x 1,000,000) or
    "K" (x 1,000). Entries without a suffix (e.g. "€0") are parsed as plain
    numbers.

    Args:
        amount: iterable of currency strings.

    Returns:
        A list with one number per input entry. Entries that are not strings
        (e.g. NaN for missing data), empty strings, and suffix-less entries
        that are not numeric all become 0.

    Raises:
        ValueError: if an entry ends in "M"/"K" but the text between the
            currency symbol and the suffix is not a number.
    """
    multipliers = {"M": 1000000, "K": 1000}
    new_amount = []
    for s in amount:
        # Missing values (NaN/None) or empty strings have no amount to parse.
        if not isinstance(s, str) or not s:
            new_amount.append(0)
            continue
        suffix = s[-1]
        if suffix in multipliers:
            # Drop the leading currency symbol and the trailing suffix.
            new_amount.append(float(s[1:-1]) * multipliers[suffix])
        else:
            # No magnitude suffix: the rest after the currency symbol is the amount.
            try:
                new_amount.append(float(s[1:]))
            except ValueError:
                new_amount.append(0)
    return new_amount
# Replace the "Value" and "Wage" currency strings (e.g. "€110.5M") with numbers.
df_opa["Value"] = currencystrtoint(df_opa["Value"].tolist())
df_opa["Wage"] = currencystrtoint(df_opa["Wage"].tolist())
# print(df_opa.describe())
# sb.pairplot(df_opa)
# plt.show()
# sb.lmplot(data=df_opa,x="Overall",y="Wage",order=2,scatter_kws={"alpha":0.3,"color":"y"})
# plt.show()
# Section marker below: "choose the best polynomial degree to predict wages".
'''选择最佳度来预测工资'''
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn import metrics
# Features and target for the wage model. Hold out 30% of the rows for testing,
# with a fixed random_state so the split is reproducible.
x, y = df_opa[["Age", "Overall", "Potential", "Value"]], df_opa["Wage"]
Xtrain, Xtest, ytrain, ytest = train_test_split(
    x, y, test_size=0.3, random_state=101
)
def pred_wage(degree, Xtrain, Xtest, ytrain):
    """Fit a (polynomial) linear regression on the training data and predict Xtest.

    Args:
        degree: polynomial degree. When greater than 1, the features are first
            expanded with PolynomialFeatures; otherwise a plain LinearRegression
            is fitted on the raw features.
        Xtrain: training feature matrix.
        Xtest: test feature matrix to predict.
        ytrain: training targets.

    Returns:
        The model's predictions for Xtest.
    """
    if degree > 1:
        poly = PolynomialFeatures(degree=degree)
        Xtrain = poly.fit_transform(Xtrain)
        # Only transform the test set: the transformer is fitted on the
        # training data alone and then applied unchanged to unseen data.
        Xtest = poly.transform(Xtest)
    lm = LinearRegression()
    lm.fit(Xtrain, ytrain)
    return lm.predict(Xtest)
# Evaluate polynomial degrees 1-10 on the held-out test set.
degrees = range(1, 11)
MAE, MSE, RMSE = [], [], []
for i in degrees:
    predicted_wages = pred_wage(i, Xtrain, Xtest, ytrain)
    # Mean absolute error regression loss.
    MAE.append(metrics.mean_absolute_error(ytest, predicted_wages))
    # Mean squared error regression loss; computed once and reused for RMSE.
    mse = metrics.mean_squared_error(ytest, predicted_wages)
    MSE.append(mse)
    # Root mean squared error: square root of the MSE.
    RMSE.append(np.sqrt(mse))
# print(MAE)
# Plot each error metric against the polynomial degree (x axis = real degree).
# plt.figure(figsize=[11,8])
# plt.subplot(2,2,1)
# plt.plot(degrees,MAE,color="red")
# plt.xlabel("Degree")
# plt.ylabel("Mean Absolute Error")
# plt.subplot(2,2,2)
# plt.plot(degrees,MSE,color="green")
# plt.xlabel("Degree")
# plt.ylabel("Mean Squared Error")
# plt.subplot(2,2,3)
# plt.plot(degrees,RMSE,color="yellow")
# plt.xlabel("Degree")
# plt.ylabel("Root Mean Squared Error")
# plt.show()
# Final model: degree 2 (presumably picked from the error curves above).
predicted_wages = pred_wage(2, Xtrain, Xtest, ytrain)
# sb.regplot(ytest,predicted_wages,scatter_kws={"alpha":0.3,"color":"y"})
# plt.xlabel("Actual Wage")
# plt.ylabel("Predicted Wage")
# plt.show()
# Distribution of the residuals (actual - predicted wage).
# NOTE: sb.distplot is deprecated since seaborn 0.11; on newer versions use
# sb.histplot(..., kde=True, stat="density") instead.
sb.distplot(ytest - predicted_wages)
# Zoom in on residuals within +/-50,000.
plt.axis([-50000, 50000, 0, 0.00016])
plt.show()