PaddlePaddle Learning Competition: League of Legends Master Prediction, 85.365-Point Solution (February 2023)

★★★ This article comes from a featured project in the AI Studio community.



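Competition link - PaddlePaddle Learning Competition: League of Legends Master Prediction

Problem description

This is a typical classification problem. It is set in the League of Legends mobile game, and participants must predict whether a player wins or loses the current match from that player's real-time game statistics.

Data description

Each row of the dataset holds one player's game statistics. The fields are as follows: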

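  • id: player record ID
  • win: whether the player won (the label)
  • kills: number of kills
  • deaths: number of deaths
  • assists: number of assists
  • largestkillingspree: largest killing spree (game term for a rampage, earned when you kill three enemy champions in a row without dying in between)
  • largestmultikill: largest multikill (game term for several kills within a short time)
  • longesttimespentliving: longest time spent alive
  • doublekills: number of double kills
  • triplekills: number of triple kills
  • quadrakills: number of quadra kills
  • pentakills: number of penta kills
  • totdmgdealt: total damage dealt
  • magicdmgdealt: magic damage dealt
  • physicaldmgdealt: physical damage dealt
  • truedmgdealt: true damage dealt
  • largestcrit: largest critical strike damage
  • totdmgtochamp: damage dealt to enemy players
  • magicdmgtochamp: magic damage dealt to enemy players
  • physdmgtochamp: physical damage dealt to enemy players
  • truedmgtochamp: true damage dealt to enemy players
  • totheal: total healing
  • totunitshealed: total number of units healed
  • dmgtoturrets: damage dealt to turrets
  • timecc: crowd-control time
  • totdmgtaken: total damage taken
  • magicdmgtaken: magic damage taken
  • physdmgtaken: physical damage taken
  • truedmgtaken: true damage taken
  • wardsplaced: number of wards placed
  • wardskilled: number of wards destroyed
  • firstblood: whether the player got first blood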

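In the test set the label field win is empty. Participants must predict it.

Challenges

The main difficulties of this competition lie in data processing and model selection:

  • The data has no missing values and no obvious outliers, and the distributions look reasonable, so feature engineering offers few ways to raise accuracy.
  • In fact, even without any data processing, random forest, LightGBM and XGBoost models easily reach 82 to 84 points. Neural networks do not improve much on this and can even do worse than tree models.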

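Highlights

  • Used a genetic algorithm (genetic programming via gplearn) to construct additional features, then reduced their dimensionality with PCA

  • Designed the network structure around the Dense Connection idea to prevent vanishing gradients

Code repository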

https://github.com/ZhangzrJerry/paddle-lolmp

  1. Unzip the data
!unzip -d data/ data/data137276/test.csv.zip
!unzip -d data/ data/data137276/train.csv.zip
!pip install -r requirements.txt  -i https://pypi.tuna.tsinghua.edu.cn/simple/
!mkdir result pretrain
Archive:  data/data137276/test.csv.zip
  inflating: data/test.csv           
Archive:  data/data137276/train.csv.zip
  inflating: data/train.csv          
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Collecting category_encoders
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/8b/9f/09d149aab1296254fe83e34aaddc59ade820acb15d529773d753df3384bf/category_encoders-2.6.0-py2.py3-none-any.whl (81 kB)
Requirement already satisfied: seaborn in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.10.0)
Requirement already satisfied: numpy in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (1.19.5)
Requirement already satisfied: pandas in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (1.1.5)
Requirement already satisfied: sklearn in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 5)) (0.0)
Collecting gplearn
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/cc/b0/063b2ddfd9258c4f43abd3b5d13ca94b53e9479a3f21df18fafe3948b67d/gplearn-0.4.2-py3-none-any.whl (25 kB)
Requirement already satisfied: IPython in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from -r requirements.txt (line 7)) (7.34.0)
Collecting pydotplus
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/60/bf/62567830b700d9f6930e9ab6831d6ba256f7b0b730acb37278b0ccdffacf/pydotplus-2.0.2.tar.gz (278 kB)
  Preparing metadata (setup.py) ... done
Requirement already satisfied: scikit-learn>=0.20.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from category_encoders->-r requirements.txt (line 1)) (0.24.2)
Collecting statsmodels>=0.9.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/91/8e/062b268b8e6d19382cbf2f97ac0384285511790718ce90bbfb1eb5e44b07/statsmodels-0.13.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (9.9 MB)
Requirement already satisfied: scipy>=1.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from category_encoders->-r requirements.txt (line 1)) (1.6.3)
Collecting patsy>=0.5.1
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/2a/e4/b3263b0e353f2be7b14f044d57874490c9cef1798a435f038683acea5c98/patsy-0.5.3-py2.py3-none-any.whl (233 kB)
Requirement already satisfied: matplotlib>=2.1.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from seaborn->-r requirements.txt (line 2)) (2.2.3)
Requirement already satisfied: python-dateutil>=2.7.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas->-r requirements.txt (line 4)) (2.8.2)
Requirement already satisfied: pytz>=2017.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pandas->-r requirements.txt (line 4)) (2019.3)
Collecting scikit-learn>=0.20.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/bd/05/e561bc99a615b5c099c7a9355409e5e57c525a108f1c2e156abb005b90a6/scikit_learn-1.0.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (24.8 MB)
Collecting joblib>=1.0.0
  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/91/d4/3b4c8e5a30604df4c7518c562d4bf0502f2fa29221459226e140cf846512/joblib-1.2.0-py3-none-any.whl (297 kB)
Requirement already satisfied: decorator in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (4.4.2)
Requirement already satisfied: pickleshare in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (0.7.5)
Requirement already satisfied: jedi>=0.16 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (0.17.2)
Requirement already satisfied: pexpect>4.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (4.7.0)
Requirement already satisfied: traitlets>=4.2 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (5.4.0)
Requirement already satisfied: setuptools>=18.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (56.2.0)
Requirement already satisfied: backcall in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (0.1.0)
Requirement already satisfied: pygments in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (2.13.0)
Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (2.0.10)
Requirement already satisfied: matplotlib-inline in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from IPython->-r requirements.txt (line 7)) (0.1.6)
Requirement already satisfied: pyparsing>=2.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pydotplus->-r requirements.txt (line 8)) (3.0.9)
Requirement already satisfied: parso<0.8.0,>=0.7.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from jedi>=0.16->IPython->-r requirements.txt (line 7)) (0.7.1)
Requirement already satisfied: six>=1.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib>=2.1.2->seaborn->-r requirements.txt (line 2)) (1.16.0)
Requirement already satisfied: cycler>=0.10 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib>=2.1.2->seaborn->-r requirements.txt (line 2)) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from matplotlib>=2.1.2->seaborn->-r requirements.txt (line 2)) (1.1.0)
Requirement already satisfied: ptyprocess>=0.5 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from pexpect>4.3->IPython->-r requirements.txt (line 7)) (0.7.0)
Requirement already satisfied: wcwidth in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->IPython->-r requirements.txt (line 7)) (0.1.7)
Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from scikit-learn>=0.20.0->category_encoders->-r requirements.txt (line 1)) (2.1.0)
Requirement already satisfied: packaging>=21.3 in /opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages (from statsmodels>=0.9.0->category_encoders->-r requirements.txt (line 1)) (21.3)
Building wheels for collected packages: pydotplus
  Building wheel for pydotplus (setup.py) ... done
  Created wheel for pydotplus: filename=pydotplus-2.0.2-py3-none-any.whl size=24566 sha256=6ae026c23d9c6a62321edd556da9ca4e035a4bc23b4c36f5cc1abb48cd91873b
  Stored in directory: /home/aistudio/.cache/pip/wheels/5c/0c/20/fd91edec2a19961da82914c46e465380335e60bc30253db979
Successfully built pydotplus
Installing collected packages: pydotplus, patsy, joblib, scikit-learn, statsmodels, gplearn, category_encoders
  Attempting uninstall: joblib
    Found existing installation: joblib 0.14.1
    Uninstalling joblib-0.14.1:
      Successfully uninstalled joblib-0.14.1
  Attempting uninstall: scikit-learn
    Found existing installation: scikit-learn 0.24.2
    Uninstalling scikit-learn-0.24.2:
      Successfully uninstalled scikit-learn-0.24.2
Successfully installed category_encoders-2.6.0 gplearn-0.4.2 joblib-1.2.0 patsy-0.5.3 pydotplus-2.0.2 scikit-learn-1.0.2 statsmodels-0.13.5
import numpy as np
import pandas as pd

x_train = pd.read_csv('data/train.csv').drop('win',axis=1)
y_train = pd.read_csv('data/train.csv')['win']
x_test = pd.read_csv('data/test.csv')

# Concatenate train and test so that feature engineering is applied to both at once
feature = pd.concat([x_train, x_test])
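By default `pd.concat` keeps each frame's original row labels. This is why `feature.info()` further down reports `Int64Index: 200000 entries, 0 to 19999`: the training rows are labelled 0 to 179999, and the test rows start again at 0. The toy sketch below uses made-up frames, not the competition data, to contrast that default with `ignore_index=True`, which renumbers the rows. A unique index is generally the safer choice whenever frames are later aligned by index, as in the `axis=1` concat during preprocessing. Pandas refuses to reindex on non-unique labels.

```python
import pandas as pd

# Two tiny stand-ins for the train (3 rows) and test (2 rows) frames
train = pd.DataFrame({"kills": [1, 2, 3]})
test = pd.DataFrame({"kills": [4, 5]})

# Default: each frame keeps its own labels, so they repeat
stacked = pd.concat([train, test])
# ignore_index=True: rows are renumbered 0..n-1
renumbered = pd.concat([train, test], ignore_index=True)

print(stacked.index.tolist())     # [0, 1, 2, 0, 1]
print(renumbered.index.tolist())  # [0, 1, 2, 3, 4]
print(stacked.index.is_unique)    # False
```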
  2. Data visualization
import matplotlib.pyplot as plt
import seaborn as sns

# Inspect basic information about the data
print(feature.describe())
print(feature.info())

# Inspect the distribution of each feature
sns.FacetGrid(pd.melt(feature), col="variable",  col_wrap=4, sharex=False, sharey=False).map(sns.distplot, "value")

# Inspect correlations between features
sns.set_context({"figure.figsize":(8,8)})
sns.heatmap(data=feature.corr(),square=True,cmap='RdBu_r')
                  id          kills         deaths        assists  \
count  200000.000000  200000.000000  200000.000000  200000.000000   
mean    99999.500000       5.798545       5.810190       8.322870   
std     57735.171256       4.605316       3.263815       5.933893   
min         0.000000       0.000000       0.000000       0.000000   
25%     49999.750000       2.000000       3.000000       4.000000   
50%     99999.500000       5.000000       6.000000       7.000000   
75%    149999.250000       8.000000       8.000000      12.000000   
max    199999.000000      39.000000      23.000000      52.000000   

       largestkillingspree  largestmultikill  longesttimespentliving  \
count        200000.000000     200000.000000           200000.000000   
mean              2.671450          1.332095              630.531655   
std               2.537784          0.758037              311.568408   
min               0.000000          0.000000                0.000000   
25%               0.000000          1.000000              433.000000   
50%               2.000000          1.000000              590.000000   
75%               4.000000          2.000000              792.000000   
max              31.000000          5.000000             3038.000000   

         doublekills    triplekills    quadrakills  ...  totunitshealed  \
count  200000.000000  200000.000000  200000.000000  ...   200000.000000   
mean        0.540265       0.073220       0.010235  ...        2.253135   
std         0.924831       0.295887       0.104548  ...        2.481890   
min         0.000000       0.000000       0.000000  ...        0.000000   
25%         0.000000       0.000000       0.000000  ...        1.000000   
50%         0.000000       0.000000       0.000000  ...        1.000000   
75%         1.000000       0.000000       0.000000  ...        3.000000   
max        11.000000       7.000000       4.000000  ...       98.000000   

        dmgtoturrets    timecc    totdmgtaken  magicdmgtaken   physdmgtaken  \
count  200000.000000  200000.0  200000.000000  200000.000000  200000.000000   
mean     2138.209810       0.0   23226.733180    8136.551010   14039.533350   
std      2934.306106       0.0   11873.669826    5161.055339    7754.110833   
min         0.000000       0.0       0.000000       0.000000       0.000000   
25%         0.000000       0.0   15264.000000    4521.000000    8627.000000   
50%       986.000000       0.0   21531.000000    7246.000000   12803.000000   
75%      3222.250000       0.0   29465.250000   10739.000000   18205.000000   
max     55083.000000       0.0  118130.000000   71631.000000   73172.000000   

        truedmgtaken    wardsplaced    wardskilled     firstblood  
count  200000.000000  200000.000000  200000.000000  200000.000000  
mean     1049.892170      11.508290       1.782860       0.100380  
std      1266.146212       7.539761       2.226049       0.300507  
min         0.000000       0.000000       0.000000       0.000000  
25%       274.000000       7.000000       0.000000       0.000000  
50%       656.000000      10.000000       1.000000       0.000000  
75%      1352.000000      14.000000       3.000000       0.000000  
max     25140.000000     322.000000      48.000000       1.000000  

[8 rows x 31 columns]
<class 'pandas.core.frame.DataFrame'>
Int64Index: 200000 entries, 0 to 19999
Data columns (total 31 columns):
 #   Column                  Non-Null Count   Dtype
---  ------                  --------------   -----
 0   id                      200000 non-null  int64
 1   kills                   200000 non-null  int64
 2   deaths                  200000 non-null  int64
 3   assists                 200000 non-null  int64
 4   largestkillingspree     200000 non-null  int64
 5   largestmultikill        200000 non-null  int64
 6   longesttimespentliving  200000 non-null  int64
 7   doublekills             200000 non-null  int64
 8   triplekills             200000 non-null  int64
 9   quadrakills             200000 non-null  int64
 10  pentakills              200000 non-null  int64
 11  totdmgdealt             200000 non-null  int64
 12  magicdmgdealt           200000 non-null  int64
 13  physicaldmgdealt        200000 non-null  int64
 14  truedmgdealt            200000 non-null  int64
 15  largestcrit             200000 non-null  int64
 16  totdmgtochamp           200000 non-null  int64
 17  magicdmgtochamp         200000 non-null  int64
 18  physdmgtochamp          200000 non-null  int64
 19  truedmgtochamp          200000 non-null  int64
 20  totheal                 200000 non-null  int64
 21  totunitshealed          200000 non-null  int64
 22  dmgtoturrets            200000 non-null  int64
 23  timecc                  200000 non-null  int64
 24  totdmgtaken             200000 non-null  int64
 25  magicdmgtaken           200000 non-null  int64
 26  physdmgtaken            200000 non-null  int64
 27  truedmgtaken            200000 non-null  int64
 28  wardsplaced             200000 non-null  int64
 29  wardskilled             200000 non-null  int64
 30  firstblood              200000 non-null  int64
dtypes: int64(31)
memory usage: 48.8 MB
None


/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/seaborn/distributions.py:288: UserWarning: Data must have variance to compute a kernel density estimate.
  warnings.warn(msg, UserWarning)





[Figure: plots produced by the visualization code above (image unavailable)]
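The seaborn `UserWarning` above most likely comes from `timecc`. The `describe()` output gives it a mean of 0.0, a standard deviation of 0.0, and a min and max of 0, so it is constant in this data. The preprocessing step drops it. The helper below flags such columns directly instead of relying on a plotting warning. It is a sketch, and `constant_columns` is a hypothetical name, not part of any library.

```python
import pandas as pd

def constant_columns(df):
    """Return the names of columns that contain a single distinct value."""
    return [col for col in df.columns if df[col].nunique(dropna=False) <= 1]

# Toy frame: 'timecc' is constant, mirroring what describe() shows above
demo = pd.DataFrame({"kills": [1, 4, 2], "timecc": [0, 0, 0]})
print(constant_columns(demo))  # ['timecc']
```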

  3. Data preprocessing
from category_encoders import OrdinalEncoder
from sklearn.preprocessing import MinMaxScaler

# Ordinal-encode the discrete variables
columns= ['kills','deaths','assists', 'largestkillingspree',
    'largestmultikill', 'longesttimespentliving', 'doublekills',
    'triplekills', 'quadrakills', 'pentakills','firstblood']
m_feature = OrdinalEncoder(cols=columns).fit(feature).transform(feature)
# Give the encoded columns an "a" prefix so the raw columns can be appended next to them
m_feature.columns=[
    'id', 'akills', 'adeaths', 'aassists', 'alargestkillingspree',
       'alargestmultikill', 'alongesttimespentliving', 'adoublekills',
       'atriplekills', 'aquadrakills', 'apentakills', 'totdmgdealt',
       'magicdmgdealt', 'physicaldmgdealt', 'truedmgdealt', 'largestcrit',
       'totdmgtochamp', 'magicdmgtochamp', 'physdmgtochamp', 'truedmgtochamp',
       'totheal', 'totunitshealed', 'dmgtoturrets', 'timecc', 'totdmgtaken',
       'magicdmgtaken', 'physdmgtaken', 'truedmgtaken', 'wardsplaced',
       'wardskilled', 'afirstblood'
]
# Append the raw (unencoded) columns, then drop the constant timecc column and the id column
m_feature=pd.concat([m_feature,feature[columns]],axis=1).drop(['timecc','id'],axis=1)

# Min-max normalize every feature to [0, 1]
sc = MinMaxScaler().fit(m_feature)
feature = sc.transform(m_feature)

pd.DataFrame(feature).to_csv('pretrain/feature.csv',index=False)

# The first 180,000 rows are the training set, the remaining 20,000 the test set
pd.DataFrame(feature[:180000]).to_csv('pretrain/train_feature.csv',index=False)
pd.DataFrame(feature[180000:]).to_csv('pretrain/test_feature.csv',index=False)
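The following NumPy-only sketch shows what the two preprocessing steps do to a single column. It assumes, following the category_encoders documentation, that `OrdinalEncoder` assigns integer codes starting at 1 in order of first appearance. Its special codes for unknown and missing values are not modelled here. It also assumes that min-max scaling maps each column to [0, 1] the way `MinMaxScaler` does. The codes need not follow the numeric order of the values: below, 5 maps to 1 and 0 maps to 3. The code above keeps both the encoded `a`-prefixed columns and the raw columns.

```python
import numpy as np

def ordinal_encode(values):
    """Map each distinct value to 1, 2, 3, ... in order of first appearance (assumed behaviour)."""
    mapping = {}
    for v in values:
        if v not in mapping:
            mapping[v] = len(mapping) + 1
    return np.array([mapping[v] for v in values]), mapping

def min_max_scale(x):
    """Scale each column of a 2-D array to [0, 1]; constant columns become 0."""
    x = np.asarray(x, dtype=float)
    lo = x.min(axis=0)
    span = x.max(axis=0) - lo
    span[span == 0] = 1.0  # avoid dividing by zero on constant columns
    return (x - lo) / span

kills = [5, 2, 5, 0, 2]
codes, mapping = ordinal_encode(kills)
print(codes)    # [1 2 1 3 2]
print(mapping)  # {5: 1, 2: 2, 0: 3}

# Scale the raw column and its codes side by side, as the pipeline keeps both
scaled = min_max_scale(np.column_stack([kills, codes]))
print(scaled)
```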
  4. Automatic feature mining
from gplearn.genetic import SymbolicTransformer
from IPython.display import Image, display
import pydotplus
import numpy as np
import pandas as pd

function_set = ['add', 'sub', 'mul', 'div', 'log', 'sqrt', 'abs', 'neg', 'max', 'min']

st = SymbolicTransformer(
    generations=20,
    population_size=1000,
    hall_of_fame=100,
    n_components=100,
    function_set=function_set,
    parsimony_coefficient=0.0005,
    max_samples=0.9,
    verbose=1,
    random_state=0,
    n_jobs=3
)

# Load the scaled features and the labels once, then reuse them
train_x = np.array(pd.read_csv('pretrain/train_feature.csv'))
test_x = np.array(pd.read_csv('pretrain/test_feature.csv'))
train_y = np.array(pd.read_csv('data/train.csv')['win'])

st.fit(train_x, train_y)

# Render the best evolved program as an expression tree
graph = st._best_programs[0].export_graphviz()
graph = pydotplus.graphviz.graph_from_dot_data(graph)
display(Image(graph.create_png()))

# Save [evolved features | original features] for the train and test sets
pd.DataFrame(
    np.concatenate([st.transform(train_x), train_x], axis=1)
).to_csv('pretrain/gptrain.csv',index=False)

pd.DataFrame(
    np.concatenate([st.transform(test_x), test_x], axis=1)
).to_csv('pretrain/gptest.csv',index=False)
    |   Population Average    |             Best Individual              |
---- ------------------------- ------------------------------------------ ----------
 Gen   Length          Fitness   Length          Fitness      OOB Fitness  Time Left
   0    12.03        0.0967885        4         0.395269         0.413457      2.82m
   1     6.91         0.250165        5         0.512841         0.518938      1.67m
   2     4.82         0.322645        7         0.600911         0.603908      1.49m
   3     5.58         0.365057       12         0.604385         0.606605      1.34m
   4     6.48         0.414963        9         0.615257         0.610385      1.35m
   5     8.58         0.468763        9         0.628536         0.620473      1.22m
   6    11.07         0.508876       15         0.638086         0.641776      1.18m
   7    13.61         0.514324       11         0.640324         0.636264      1.12m
   8    14.01         0.540671       39         0.656748         0.657258      1.09m
   9    15.80         0.544383       39         0.656818         0.656685      1.02m
  10    18.48         0.556065       41         0.656622         0.658435     54.90s
  11    19.58         0.569885       50          0.66438         0.666361     49.46s
  12    20.93          0.58614       39         0.665209         0.665554     46.77s
  13    22.09         0.589301       53         0.666591          0.67299     38.94s
  14    23.42         0.597276       59         0.666831         0.671305     33.80s
  15    25.60         0.600218       65          0.66804         0.668831     28.90s
  16    27.89         0.618124       76         0.667599         0.668118     22.03s
  17    31.54           0.6234       62         0.668819         0.660909     15.79s
  18    33.64         0.631025       55         0.669813         0.669861      7.87s
  19    32.85         0.625123       52         0.672489         0.662052      0.00s

[Figure: expression tree of the best program found by SymbolicTransformer (image unavailable)]
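`function_set` includes `div`, `log` and `sqrt`. gplearn applies these in protected form so that evolved programs never produce infinities or NaNs. The code does not set `metric`, so `SymbolicTransformer` uses its default, the absolute Pearson correlation. The Fitness columns in the log above can therefore be read as correlations between a program's output and `win`. During evolution gplearn also subtracts `parsimony_coefficient` times the program length when selecting parents, which discourages bloated programs. The sketch below re-implements these pieces in NumPy for illustration only. The thresholds are assumed to mirror gplearn's implementation, the toy values are made up, and none of this is gplearn's public API.

```python
import numpy as np

# Protected operators in the style of gplearn (assumed thresholds)
def protected_div(a, b):
    """a / b, but 1.0 wherever |b| <= 0.001."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(np.abs(b) > 0.001, np.divide(a, b), 1.0)

def protected_log(a):
    """log|a|, but 0.0 wherever |a| <= 0.001."""
    with np.errstate(divide="ignore", invalid="ignore"):
        return np.where(np.abs(a) > 0.001, np.log(np.abs(a)), 0.0)

def protected_sqrt(a):
    """sqrt|a|, defined for negative inputs too."""
    return np.sqrt(np.abs(a))

def pearson_fitness(feature, y):
    """Absolute Pearson correlation between a candidate feature and the label."""
    f = feature - feature.mean()
    t = y - y.mean()
    denom = np.sqrt((f ** 2).sum() * (t ** 2).sum())
    return 0.0 if denom == 0 else float(abs((f * t).sum()) / denom)

# Hypothetical scaled columns and labels
kills = np.array([0.1, 0.8, 0.4, 0.9, 0.0])
deaths = np.array([0.7, 0.1, 0.5, 0.2, 0.0])
win = np.array([0.0, 1.0, 0.0, 1.0, 0.0])

# One candidate program: div(kills, deaths); the zero denominator falls back to 1.0
candidate = protected_div(kills, deaths)
print(candidate)
print(round(pearson_fitness(candidate, win), 3))
```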

  5. Feature dimensionality reduction
import pandas as pd
from sklearn.decomposition import PCA

traindata = pd.read_csv('pretrain/gptrain.csv')
testdata = pd.read_csv('pretrain/gptest.csv')

transfer = PCA(n_components=64)
pd.DataFrame(
    transfer.fit_transform(traindata)
).to_csv('pretrain/pcatrain.csv',index=False)
pd.DataFrame(
    transfer.transform(testdata)
).to_csv('pretrain/pcatest.csv',index=False)
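`fit_transform` learns the principal axes from the training features only, and `transform` projects the test features onto those same axes. The NumPy sketch below shows that fit/apply split via SVD, using random data as a stand-in. scikit-learn's `PCA` additionally fixes the sign of each axis, so individual components may differ in sign from this version.

```python
import numpy as np

def pca_fit(x, n_components):
    """Return the column means and the top principal axes (rows) of x."""
    mean = x.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(x - mean, full_matrices=False)
    return mean, vt[:n_components]

def pca_transform(x, mean, components):
    """Project x onto axes learned from the training data."""
    return (x - mean) @ components.T

rng = np.random.default_rng(0)
train = rng.normal(size=(200, 10))
test = rng.normal(size=(50, 10))

mean, comps = pca_fit(train, n_components=4)
train_low = pca_transform(train, mean, comps)
test_low = pca_transform(test, mean, comps)  # reuse the training mean and axes
print(train_low.shape, test_low.shape)  # (200, 4) (50, 4)
```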
  6. Model construction
from paddle import nn
import paddle
import paddle.nn.functional as F

class DenseBlock(nn.Layer):
    # Two linear layers followed by ReLU; the result is concatenated with the
    # original 64-d network input (the "dense connection")
    def __init__(self):
        super(DenseBlock, self).__init__()
        self.fc1 = nn.Linear(96, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, input, dense):
        x = self.fc1(input)  # 96 -> 64 (no activation between fc1 and fc2)
        x = self.fc2(x)      # 64 -> 32
        x = F.relu(x)
        # 32 block features + 64 input features = 96, the width the next block expects
        x = paddle.concat([x, dense],axis=1)
        return x


class MyNet(nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc0 = nn.Linear(64, 96)  # 64 PCA components -> 96
        self.bl1 = DenseBlock()
        self.bl2 = DenseBlock()
        self.bl3 = DenseBlock()
        self.bl4 = DenseBlock()
        self.bl5 = DenseBlock()
        self.bl6 = DenseBlock()
        self.bl7 = DenseBlock()
        self.bl8 = DenseBlock()
        self.bl9 = DenseBlock()
        self.fc4 = nn.Linear(96, 32)
        self.fc5 = nn.Linear(32, 1)

    def forward(self, input):
        x = self.fc0(input)
        # Every block also receives the raw network input as its dense connection
        x = self.bl1(x, input)
        x = self.bl2(x, input)
        x = self.bl3(x, input)
        x = self.bl4(x, input)
        x = self.bl5(x, input)
        x = self.bl6(x, input)
        x = self.bl7(x, input)
        x = self.bl8(x, input)
        x = self.bl9(x, input)
        x = self.fc4(x)
        x = self.fc5(x)
        # Returns a win probability in (0, 1), not a raw logit
        return F.sigmoid(x)

paddle.summary(MyNet(),(180000,64))
W0216 17:52:57.241905  5934 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0216 17:52:57.244976  5934 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.


-----------------------------------------------------------------------------------
 Layer (type)           Input Shape              Output Shape         Param #    
===================================================================================
   Linear-1            [[180000, 64]]            [180000, 96]          6,240     
   Linear-2            [[180000, 96]]            [180000, 64]          6,208     
   Linear-3            [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-1   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-4            [[180000, 96]]            [180000, 64]          6,208     
   Linear-5            [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-2   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-6            [[180000, 96]]            [180000, 64]          6,208     
   Linear-7            [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-3   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-8            [[180000, 96]]            [180000, 64]          6,208     
   Linear-9            [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-4   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-10           [[180000, 96]]            [180000, 64]          6,208     
   Linear-11           [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-5   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-12           [[180000, 96]]            [180000, 64]          6,208     
   Linear-13           [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-6   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-14           [[180000, 96]]            [180000, 64]          6,208     
   Linear-15           [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-7   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-16           [[180000, 96]]            [180000, 64]          6,208     
   Linear-17           [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-8   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-18           [[180000, 96]]            [180000, 64]          6,208     
   Linear-19           [[180000, 64]]            [180000, 32]          2,080     
 DenseBlock-9   [[180000, 96], [180000, 64]]     [180000, 96]            0       
   Linear-20           [[180000, 96]]            [180000, 32]          3,104     
   Linear-21           [[180000, 32]]            [180000, 1]            33       
===================================================================================
Total params: 83,969
Trainable params: 83,969
Non-trainable params: 0
-----------------------------------------------------------------------------------
Input size (MB): 43.95
Forward/backward pass size (MB): 2550.20
Params size (MB): 0.32
Estimated Total Size (MB): 2594.47
-----------------------------------------------------------------------------------






{'total_params': 83969, 'trainable_params': 83969}
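The parameter count in the summary can be checked by hand. Each `nn.Linear(in_features, out_features)` holds an `in_features * out_features` weight matrix plus `out_features` biases. Every `DenseBlock` outputs its 32 features concatenated with the 64-dimensional input, giving 96 features, which is exactly what the next block's `fc1` expects. A quick arithmetic check in plain Python:

```python
def linear_params(n_in, n_out):
    """Weights (n_in * n_out) plus biases (n_out) of one fully connected layer."""
    return n_in * n_out + n_out

fc0 = linear_params(64, 96)                            # 6,240
block = linear_params(96, 64) + linear_params(64, 32)  # 6,208 + 2,080 per DenseBlock
head = linear_params(96, 32) + linear_params(32, 1)    # 3,104 + 33
block_out = 32 + 64  # fc2 output concatenated with the 64-d network input
total = fc0 + 9 * block + head

print(block_out)  # 96
print(total)      # 83969
```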
  7. Model training
import paddle.nn.functional as F
import paddle
import pandas as pd
import numpy as np

def train_pm(model, optimizer, feature, label, epoches=1):
    # Train on GPU 0
    paddle.device.set_device('gpu:0')

    # Convert the inputs to tensors once instead of on every epoch
    feature = paddle.to_tensor(feature)
    label = paddle.reshape(paddle.to_tensor(label),(-1,1))

    print('start training ')
    model.train()
    for epoch in range(epoches):
        # Full-batch forward pass to get the predictions
        logits = model(feature)
        # Note: MyNet already ends in a sigmoid, so this loss applies a second one
        loss = F.binary_cross_entropy_with_logits(logits, label)
        avg_loss = paddle.mean(loss)
        # Backpropagate, update the weights, then clear the gradients
        avg_loss.backward()
        optimizer.step()
        optimizer.clear_grad()

        # Evaluate on the training data itself (there is no held-out split)
        model.eval()
        accuracies = []
        losses = []

        with paddle.no_grad():
            # Forward pass again with the updated weights
            logits = model(feature)
            # The model output is already the sigmoid probability of a win
            pred = logits
            loss = F.binary_cross_entropy_with_logits(logits, label)
            # Probability of the other class: 1 - p
            pred2 = pred * (-1.0) + 1.0
            # Stack [1 - p, p] along axis 1; top-1 accuracy then equals thresholding p at 0.5
            pred = paddle.concat([pred2, pred], axis=1)
            acc = paddle.metric.accuracy(pred, paddle.cast(label, dtype='int64'))

        accuracies.append(acc.numpy())
        losses.append(loss.numpy())

        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(np.mean(accuracies), np.mean(losses)))
        model.train()

        # paddle.save(model.state_dict(), 'model/lolmp{}_{}.pdparams'.format(epoch,acc.numpy()))


mynet = MyNet()
train_pm(
    mynet,
    paddle.optimizer.Adam(
        parameters=mynet.parameters(),
        learning_rate=0.005
    ),
    np.array(
        pd.read_csv('pretrain/pcatrain.csv')
    ).astype('float32'),
    np.array(
        pd.read_csv('data/train.csv')['win']
    ).astype('float32'),
    400
)
start training 
[validation] accuracy/loss: 0.5128/0.6676
[validation] accuracy/loss: 0.8079/0.6187
[validation] accuracy/loss: 0.8114/0.5981
[validation] accuracy/loss: 0.8112/0.5925
[validation] accuracy/loss: 0.8110/0.5919
[validation] accuracy/loss: 0.8105/0.5920
[validation] accuracy/loss: 0.8106/0.5913
[validation] accuracy/loss: 0.7105/0.6138
[validation] accuracy/loss: 0.8112/0.5916
[validation] accuracy/loss: 0.8116/0.5929
[validation] accuracy/loss: 0.8118/0.5949
[validation] accuracy/loss: 0.5492/0.7224
[validation] accuracy/loss: 0.8119/0.5934
[validation] accuracy/loss: 0.8120/0.5932
[validation] accuracy/loss: 0.8122/0.5930
[validation] accuracy/loss: 0.8125/0.5928
[validation] accuracy/loss: 0.8127/0.5923
[validation] accuracy/loss: 0.8131/0.5917
[validation] accuracy/loss: 0.8129/0.5909
[validation] accuracy/loss: 0.8123/0.5899
[validation] accuracy/loss: 0.8070/0.5939
[validation] accuracy/loss: 0.8122/0.5894
[validation] accuracy/loss: 0.8127/0.5894
[validation] accuracy/loss: 0.8132/0.5894
[validation] accuracy/loss: 0.8134/0.5892
[validation] accuracy/loss: 0.8135/0.5890
[validation] accuracy/loss: 0.8138/0.5888
[validation] accuracy/loss: 0.8139/0.5885
[validation] accuracy/loss: 0.8141/0.5881
[validation] accuracy/loss: 0.8141/0.5877
[validation] accuracy/loss: 0.8143/0.5874
[validation] accuracy/loss: 0.8134/0.5871
[validation] accuracy/loss: 0.8123/0.5870
[validation] accuracy/loss: 0.8115/0.5866
[validation] accuracy/loss: 0.8118/0.5860
[validation] accuracy/loss: 0.8130/0.5854
[validation] accuracy/loss: 0.8140/0.5849
[validation] accuracy/loss: 0.8143/0.5844
[validation] accuracy/loss: 0.8133/0.5840
[validation] accuracy/loss: 0.8140/0.5835
[validation] accuracy/loss: 0.8165/0.5831
[validation] accuracy/loss: 0.8169/0.5827
[validation] accuracy/loss: 0.8157/0.5824
[validation] accuracy/loss: 0.8185/0.5822
[validation] accuracy/loss: 0.8189/0.5820
[validation] accuracy/loss: 0.8174/0.5819
[validation] accuracy/loss: 0.8191/0.5817
[validation] accuracy/loss: 0.8198/0.5815
[validation] accuracy/loss: 0.8186/0.5814
[validation] accuracy/loss: 0.8202/0.5812
[validation] accuracy/loss: 0.8200/0.5810
[validation] accuracy/loss: 0.8191/0.5809
[validation] accuracy/loss: 0.8207/0.5807
[validation] accuracy/loss: 0.8198/0.5806
[validation] accuracy/loss: 0.8196/0.5805
[validation] accuracy/loss: 0.8213/0.5804
[validation] accuracy/loss: 0.8198/0.5802
[validation] accuracy/loss: 0.8202/0.5800
[validation] accuracy/loss: 0.8220/0.5799
[validation] accuracy/loss: 0.8204/0.5798
[validation] accuracy/loss: 0.8235/0.5797
[validation] accuracy/loss: 0.8219/0.5794
[validation] accuracy/loss: 0.8241/0.5792
[validation] accuracy/loss: 0.8238/0.5791
[validation] accuracy/loss: 0.8248/0.5789
[validation] accuracy/loss: 0.8250/0.5787
[validation] accuracy/loss: 0.8273/0.5786
[validation] accuracy/loss: 0.8191/0.5793
[validation] accuracy/loss: 0.8311/0.5812
[validation] accuracy/loss: 0.8311/0.5816
[validation] accuracy/loss: 0.8311/0.5794
[validation] accuracy/loss: 0.8211/0.5788
[validation] accuracy/loss: 0.8121/0.5805
[validation] accuracy/loss: 0.8299/0.5782
[validation] accuracy/loss: 0.8339/0.5797
[validation] accuracy/loss: 0.8339/0.5799
[validation] accuracy/loss: 0.8332/0.5787
[validation] accuracy/loss: 0.8291/0.5777
[validation] accuracy/loss: 0.8209/0.5783
[validation] accuracy/loss: 0.8206/0.5784
[validation] accuracy/loss: 0.8290/0.5772
[validation] accuracy/loss: 0.8332/0.5775
[validation] accuracy/loss: 0.8350/0.5779
[validation] accuracy/loss: 0.8349/0.5775
[validation] accuracy/loss: 0.8327/0.5767
[validation] accuracy/loss: 0.8278/0.5768
[validation] accuracy/loss: 0.8258/0.5771
[validation] accuracy/loss: 0.8307/0.5765
[validation] accuracy/loss: 0.8343/0.5764
[validation] accuracy/loss: 0.8363/0.5766
[validation] accuracy/loss: 0.8364/0.5764
[validation] accuracy/loss: 0.8348/0.5759
[validation] accuracy/loss: 0.8310/0.5761
[validation] accuracy/loss: 0.8312/0.5760
[validation] accuracy/loss: 0.8349/0.5757
[validation] accuracy/loss: 0.8370/0.5758
[validation] accuracy/loss: 0.8375/0.5757
[validation] accuracy/loss: 0.8356/0.5755
[validation] accuracy/loss: 0.8333/0.5755
[validation] accuracy/loss: 0.8334/0.5754
[validation] accuracy/loss: 0.8361/0.5752
[validation] accuracy/loss: 0.8378/0.5753
[validation] accuracy/loss: 0.8371/0.5751
[validation] accuracy/loss: 0.8358/0.5750
[validation] accuracy/loss: 0.8352/0.5750
[validation] accuracy/loss: 0.8364/0.5749
[validation] accuracy/loss: 0.8378/0.5749
[validation] accuracy/loss: 0.8372/0.5747
[validation] accuracy/loss: 0.8361/0.5747
[validation] accuracy/loss: 0.8367/0.5746
[validation] accuracy/loss: 0.8385/0.5745
[validation] accuracy/loss: 0.8383/0.5744
[validation] accuracy/loss: 0.8373/0.5744
[validation] accuracy/loss: 0.8381/0.5743
[validation] accuracy/loss: 0.8393/0.5742
[validation] accuracy/loss: 0.8384/0.5741
[validation] accuracy/loss: 0.8378/0.5741
[validation] accuracy/loss: 0.8394/0.5740
[validation] accuracy/loss: 0.8397/0.5739
[validation] accuracy/loss: 0.8386/0.5738
[validation] accuracy/loss: 0.8394/0.5738
[validation] accuracy/loss: 0.8402/0.5737
[validation] accuracy/loss: 0.8392/0.5736
[validation] accuracy/loss: 0.8407/0.5735
[validation] accuracy/loss: 0.8401/0.5734
[validation] accuracy/loss: 0.8407/0.5733
[validation] accuracy/loss: 0.8406/0.5732
[validation] accuracy/loss: 0.8417/0.5732
[validation] accuracy/loss: 0.8388/0.5733
[validation] accuracy/loss: 0.8430/0.5739
[validation] accuracy/loss: 0.8332/0.5744
[validation] accuracy/loss: 0.8351/0.5745
[validation] accuracy/loss: 0.8370/0.5742
[validation] accuracy/loss: 0.8367/0.5739
[validation] accuracy/loss: 0.8384/0.5738
[validation] accuracy/loss: 0.8406/0.5738
[validation] accuracy/loss: 0.8399/0.5734
[validation] accuracy/loss: 0.8383/0.5734
[validation] accuracy/loss: 0.8406/0.5732
[validation] accuracy/loss: 0.8421/0.5731
[validation] accuracy/loss: 0.8410/0.5730
[validation] accuracy/loss: 0.8416/0.5729
[validation] accuracy/loss: 0.8422/0.5728
[validation] accuracy/loss: 0.8413/0.5727
[validation] accuracy/loss: 0.8418/0.5725
[validation] accuracy/loss: 0.8440/0.5725
[validation] accuracy/loss: 0.8431/0.5723
[validation] accuracy/loss: 0.8442/0.5721
[validation] accuracy/loss: 0.8448/0.5720
[validation] accuracy/loss: 0.8431/0.5719
[validation] accuracy/loss: 0.8453/0.5717
[validation] accuracy/loss: 0.8445/0.5716
[validation] accuracy/loss: 0.8457/0.5715
[validation] accuracy/loss: 0.8446/0.5715
[validation] accuracy/loss: 0.8464/0.5714
[validation] accuracy/loss: 0.8419/0.5717
[validation] accuracy/loss: 0.8472/0.5727
[validation] accuracy/loss: 0.8378/0.5728
[validation] accuracy/loss: 0.8369/0.5731
[validation] accuracy/loss: 0.8413/0.5727
[validation] accuracy/loss: 0.8431/0.5720
[validation] accuracy/loss: 0.8409/0.5726
[validation] accuracy/loss: 0.8462/0.5720
[validation] accuracy/loss: 0.8439/0.5717
[validation] accuracy/loss: 0.8411/0.5720
[validation] accuracy/loss: 0.8453/0.5713
[validation] accuracy/loss: 0.8469/0.5718
[validation] accuracy/loss: 0.8435/0.5714
[validation] accuracy/loss: 0.8454/0.5711
[validation] accuracy/loss: 0.8461/0.5711
[validation] accuracy/loss: 0.8448/0.5710
[validation] accuracy/loss: 0.8477/0.5709
[validation] accuracy/loss: 0.8467/0.5708
[validation] accuracy/loss: 0.8470/0.5706
[validation] accuracy/loss: 0.8481/0.5706
[validation] accuracy/loss: 0.8463/0.5704
[validation] accuracy/loss: 0.8480/0.5707
[validation] accuracy/loss: 0.8475/0.5703
[validation] accuracy/loss: 0.8477/0.5705
[validation] accuracy/loss: 0.8490/0.5704
[validation] accuracy/loss: 0.8470/0.5700
[validation] accuracy/loss: 0.8485/0.5706
[validation] accuracy/loss: 0.8462/0.5704
[validation] accuracy/loss: 0.8499/0.5700
[validation] accuracy/loss: 0.8472/0.5703
[validation] accuracy/loss: 0.8495/0.5697
[validation] accuracy/loss: 0.8507/0.5697
[validation] accuracy/loss: 0.8455/0.5702
[validation] accuracy/loss: 0.8503/0.5707
[validation] accuracy/loss: 0.8404/0.5715
[validation] accuracy/loss: 0.8451/0.5710
[validation] accuracy/loss: 0.8486/0.5703
[validation] accuracy/loss: 0.8442/0.5710
[validation] accuracy/loss: 0.8505/0.5698
[validation] accuracy/loss: 0.8469/0.5706
[validation] accuracy/loss: 0.8457/0.5706
[validation] accuracy/loss: 0.8503/0.5694
[validation] accuracy/loss: 0.8501/0.5706
[validation] accuracy/loss: 0.8445/0.5704
[validation] accuracy/loss: 0.8462/0.5708
[validation] accuracy/loss: 0.8478/0.5702
[validation] accuracy/loss: 0.8445/0.5703
[validation] accuracy/loss: 0.8507/0.5702
[validation] accuracy/loss: 0.8500/0.5695
[validation] accuracy/loss: 0.8476/0.5698
[validation] accuracy/loss: 0.8501/0.5694
[validation] accuracy/loss: 0.8504/0.5692
[validation] accuracy/loss: 0.8508/0.5691
[validation] accuracy/loss: 0.8508/0.5691
[validation] accuracy/loss: 0.8518/0.5688
[validation] accuracy/loss: 0.8506/0.5690
[validation] accuracy/loss: 0.8520/0.5690
[validation] accuracy/loss: 0.8479/0.5692
[validation] accuracy/loss: 0.8522/0.5695
[validation] accuracy/loss: 0.8473/0.5694
[validation] accuracy/loss: 0.8529/0.5687
[validation] accuracy/loss: 0.8531/0.5684
[validation] accuracy/loss: 0.8503/0.5686
[validation] accuracy/loss: 0.8525/0.5696
[validation] accuracy/loss: 0.8435/0.5703
[validation] accuracy/loss: 0.8494/0.5695
[validation] accuracy/loss: 0.8518/0.5690
[validation] accuracy/loss: 0.8476/0.5698
[validation] accuracy/loss: 0.8525/0.5687
[validation] accuracy/loss: 0.8507/0.5691
[validation] accuracy/loss: 0.8511/0.5686
[validation] accuracy/loss: 0.8526/0.5692
[validation] accuracy/loss: 0.8469/0.5694
[validation] accuracy/loss: 0.8495/0.5698
[validation] accuracy/loss: 0.8510/0.5686
[validation] accuracy/loss: 0.8505/0.5692
[validation] accuracy/loss: 0.8511/0.5687
[validation] accuracy/loss: 0.8515/0.5691
[validation] accuracy/loss: 0.8517/0.5684
[validation] accuracy/loss: 0.8506/0.5689
[validation] accuracy/loss: 0.8525/0.5686
[validation] accuracy/loss: 0.8504/0.5686
[validation] accuracy/loss: 0.8524/0.5681
[validation] accuracy/loss: 0.8538/0.5682
[validation] accuracy/loss: 0.8504/0.5684
[validation] accuracy/loss: 0.8533/0.5687
[validation] accuracy/loss: 0.8527/0.5678
[validation] accuracy/loss: 0.8536/0.5678
[validation] accuracy/loss: 0.8546/0.5678
[validation] accuracy/loss: 0.8502/0.5684
[validation] accuracy/loss: 0.8545/0.5687
[validation] accuracy/loss: 0.8491/0.5686
[validation] accuracy/loss: 0.8549/0.5677
[validation] accuracy/loss: 0.8542/0.5675
[validation] accuracy/loss: 0.8523/0.5679
[validation] accuracy/loss: 0.8553/0.5677
[validation] accuracy/loss: 0.8522/0.5679
[validation] accuracy/loss: 0.8549/0.5675
[validation] accuracy/loss: 0.8538/0.5677
[validation] accuracy/loss: 0.8539/0.5675
[validation] accuracy/loss: 0.8550/0.5673
[validation] accuracy/loss: 0.8533/0.5677
[validation] accuracy/loss: 0.8536/0.5678
[validation] accuracy/loss: 0.8546/0.5673
[validation] accuracy/loss: 0.8540/0.5680
[validation] accuracy/loss: 0.8503/0.5682
[validation] accuracy/loss: 0.8517/0.5697
[validation] accuracy/loss: 0.8543/0.5676
[validation] accuracy/loss: 0.8467/0.5694
[validation] accuracy/loss: 0.8538/0.5685
[validation] accuracy/loss: 0.8511/0.5689
[validation] accuracy/loss: 0.8483/0.5689
[validation] accuracy/loss: 0.8538/0.5683
[validation] accuracy/loss: 0.8530/0.5675
[validation] accuracy/loss: 0.8540/0.5678
[validation] accuracy/loss: 0.8561/0.5671
[validation] accuracy/loss: 0.8517/0.5678
[validation] accuracy/loss: 0.8551/0.5680
[validation] accuracy/loss: 0.8532/0.5676
[validation] accuracy/loss: 0.8517/0.5679
[validation] accuracy/loss: 0.8550/0.5680
[validation] accuracy/loss: 0.8504/0.5680
[validation] accuracy/loss: 0.8553/0.5677
[validation] accuracy/loss: 0.8554/0.5672
[validation] accuracy/loss: 0.8503/0.5680
[validation] accuracy/loss: 0.8548/0.5685
[validation] accuracy/loss: 0.8532/0.5675
[validation] accuracy/loss: 0.8536/0.5672
[validation] accuracy/loss: 0.8553/0.5681
[validation] accuracy/loss: 0.8511/0.5679
[validation] accuracy/loss: 0.8538/0.5675
[validation] accuracy/loss: 0.8565/0.5673
[validation] accuracy/loss: 0.8500/0.5681
[validation] accuracy/loss: 0.8554/0.5670
[validation] accuracy/loss: 0.8559/0.5675
[validation] accuracy/loss: 0.8516/0.5677
[validation] accuracy/loss: 0.8552/0.5674
[validation] accuracy/loss: 0.8550/0.5675
[validation] accuracy/loss: 0.8515/0.5677
[validation] accuracy/loss: 0.8553/0.5672
[validation] accuracy/loss: 0.8567/0.5666
[validation] accuracy/loss: 0.8540/0.5670
[validation] accuracy/loss: 0.8564/0.5668
[validation] accuracy/loss: 0.8570/0.5663
[validation] accuracy/loss: 0.8572/0.5664
[validation] accuracy/loss: 0.8549/0.5667
[validation] accuracy/loss: 0.8567/0.5663
[validation] accuracy/loss: 0.8568/0.5666
[validation] accuracy/loss: 0.8566/0.5663
[validation] accuracy/loss: 0.8573/0.5662
[validation] accuracy/loss: 0.8569/0.5663
[validation] accuracy/loss: 0.8541/0.5673
[validation] accuracy/loss: 0.8527/0.5676
[validation] accuracy/loss: 0.8570/0.5668
[validation] accuracy/loss: 0.8532/0.5673
[validation] accuracy/loss: 0.8568/0.5662
[validation] accuracy/loss: 0.8566/0.5672
[validation] accuracy/loss: 0.8520/0.5675
[validation] accuracy/loss: 0.8538/0.5679
[validation] accuracy/loss: 0.8557/0.5669
[validation] accuracy/loss: 0.8565/0.5666
[validation] accuracy/loss: 0.8527/0.5682
[validation] accuracy/loss: 0.8558/0.5671
[validation] accuracy/loss: 0.8520/0.5686
[validation] accuracy/loss: 0.8507/0.5681
[validation] accuracy/loss: 0.8507/0.5688
[validation] accuracy/loss: 0.8528/0.5677
[validation] accuracy/loss: 0.8528/0.5677
[validation] accuracy/loss: 0.8557/0.5672
[validation] accuracy/loss: 0.8543/0.5674
[validation] accuracy/loss: 0.8536/0.5672
[validation] accuracy/loss: 0.8563/0.5669
[validation] accuracy/loss: 0.8567/0.5664
[validation] accuracy/loss: 0.8558/0.5667
[validation] accuracy/loss: 0.8570/0.5661
[validation] accuracy/loss: 0.8568/0.5663
[validation] accuracy/loss: 0.8563/0.5664
[validation] accuracy/loss: 0.8569/0.5662
[validation] accuracy/loss: 0.8568/0.5663
[validation] accuracy/loss: 0.8582/0.5658
[validation] accuracy/loss: 0.8569/0.5659
[validation] accuracy/loss: 0.8589/0.5659
[validation] accuracy/loss: 0.8555/0.5662
[validation] accuracy/loss: 0.8587/0.5663
[validation] accuracy/loss: 0.8537/0.5667
[validation] accuracy/loss: 0.8588/0.5659
[validation] accuracy/loss: 0.8568/0.5659
[validation] accuracy/loss: 0.8586/0.5655
[validation] accuracy/loss: 0.8594/0.5657
[validation] accuracy/loss: 0.8516/0.5673
[validation] accuracy/loss: 0.8557/0.5674
[validation] accuracy/loss: 0.8584/0.5661
[validation] accuracy/loss: 0.8503/0.5677
[validation] accuracy/loss: 0.8581/0.5661
[validation] accuracy/loss: 0.8589/0.5658
[validation] accuracy/loss: 0.8538/0.5666
[validation] accuracy/loss: 0.8590/0.5658
[validation] accuracy/loss: 0.8578/0.5659
[validation] accuracy/loss: 0.8573/0.5656
[validation] accuracy/loss: 0.8582/0.5666
[validation] accuracy/loss: 0.8529/0.5669
[validation] accuracy/loss: 0.8561/0.5666
[validation] accuracy/loss: 0.8583/0.5659
[validation] accuracy/loss: 0.8563/0.5660
[validation] accuracy/loss: 0.8580/0.5664
[validation] accuracy/loss: 0.8584/0.5655
[validation] accuracy/loss: 0.8578/0.5658
[validation] accuracy/loss: 0.8573/0.5664
[validation] accuracy/loss: 0.8574/0.5657
[validation] accuracy/loss: 0.8583/0.5660
[validation] accuracy/loss: 0.8571/0.5657
[validation] accuracy/loss: 0.8586/0.5659
[validation] accuracy/loss: 0.8563/0.5663
[validation] accuracy/loss: 0.8588/0.5654
[validation] accuracy/loss: 0.8584/0.5657
[validation] accuracy/loss: 0.8569/0.5656
[validation] accuracy/loss: 0.8586/0.5666
[validation] accuracy/loss: 0.8537/0.5667
[validation] accuracy/loss: 0.8567/0.5662
[validation] accuracy/loss: 0.8584/0.5656
[validation] accuracy/loss: 0.8570/0.5658
[validation] accuracy/loss: 0.8590/0.5655
[validation] accuracy/loss: 0.8581/0.5656
[validation] accuracy/loss: 0.8586/0.5655
[validation] accuracy/loss: 0.8604/0.5651
[validation] accuracy/loss: 0.8559/0.5658
[validation] accuracy/loss: 0.8593/0.5657
[validation] accuracy/loss: 0.8582/0.5654
[validation] accuracy/loss: 0.8598/0.5649
[validation] accuracy/loss: 0.8598/0.5655
[validation] accuracy/loss: 0.8570/0.5655
[validation] accuracy/loss: 0.8599/0.5655
[validation] accuracy/loss: 0.8565/0.5655
[validation] accuracy/loss: 0.8604/0.5651
[validation] accuracy/loss: 0.8595/0.5653
[validation] accuracy/loss: 0.8584/0.5653
[validation] accuracy/loss: 0.8602/0.5651
[validation] accuracy/loss: 0.8564/0.5659
[validation] accuracy/loss: 0.8594/0.5649
[validation] accuracy/loss: 0.8596/0.5654
[validation] accuracy/loss: 0.8540/0.5665
[validation] accuracy/loss: 0.8599/0.5656
[validation] accuracy/loss: 0.8580/0.5659
[validation] accuracy/loss: 0.8548/0.5660
[validation] accuracy/loss: 0.8591/0.5663
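
  1. Make predictions
mynet.eval()  # inference mode: matters if the network contains dropout or batch norm layers
with paddle.no_grad():  # gradients are not needed for prediction
    # load the PCA-transformed test features produced earlier in the notebook
    test_x = paddle.to_tensor(
        pd.read_csv('pretrain/pcatest.csv').to_numpy().astype('float32')
    )
    probs = mynet(test_x).numpy()  # network outputs, treated as win probabilities (one row per player)
# threshold at 0.5 to turn probabilities into 0/1 labels for the 'win' column
pd.DataFrame(np.where(probs > 0.5, 1, 0), columns=['win']).to_csv('submission.csv', index=False)
!zip result/submission.zip submission.csv
!rm submission.csv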
  adding: submission.csv (deflated 90%)
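The prediction step turns the network's outputs into hard 0/1 labels. The snippet below is a minimal NumPy sketch of that conversion. It assumes, as the 0.5 threshold suggests, that `mynet` ends in a sigmoid and returns win probabilities of shape (N, 1). The helper name `probs_to_submission` and the sample probabilities are hypothetical.

```python
import io

import numpy as np


def probs_to_submission(probs, threshold=0.5):
    """Convert (N, 1) or (N,) win probabilities into labels and a CSV string with a 'win' column.

    Hypothetical helper: mirrors np.where(probs > threshold, 1, 0) followed by
    DataFrame.to_csv(index=False) in the notebook, using only NumPy.
    """
    # flatten the (N, 1) network output and apply a strict '>' threshold
    labels = np.where(np.asarray(probs).reshape(-1) > threshold, 1, 0)
    buf = io.StringIO()
    # header='win' with comments='' writes a plain header line without the default '# ' prefix
    np.savetxt(buf, labels, fmt='%d', header='win', comments='')
    return labels, buf.getvalue()


# Made-up probabilities for four players, shaped like the network output
labels, csv_text = probs_to_submission(np.array([[0.91], [0.12], [0.50], [0.73]], dtype='float32'))
print(labels)  # [1 0 0 1]
print(csv_text)
```

Because the comparison is strict, a probability of exactly 0.5 maps to 0, matching the notebook's `> 0.5`.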

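Project Summary

There is still plenty of room to improve this project:

  • The network architecture could be refined further using the Network-in-Network idea.

  • Although all 180,000 samples fit into a single training batch, changing the number of batches (i.e., training with mini-batches) may still improve the score.

  • The project does not use an explicit split into training and hold-out data. Cross-validation could be tried to reduce overfitting.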
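The last two suggestions can be sketched without any deep learning framework. The NumPy code below is a minimal illustration and not the author's implementation. `kfold_indices` yields shuffled train/validation index splits for K-fold cross-validation. `minibatches` yields shuffled index batches for mini-batch training, so the 180,000 rows are not fed all at once. Both helper names, `k=5` and the batch sizes are assumptions chosen for illustration. In the notebook, each batch would be converted with `paddle.to_tensor` before the forward pass.

```python
import numpy as np


def kfold_indices(n_samples, k=5, seed=0):
    """Yield (train_idx, val_idx) pairs for K-fold cross-validation.

    Rows are shuffled once and split into k nearly equal folds.
    Each fold serves as the validation set exactly once.
    """
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    folds = np.array_split(perm, k)
    for i in range(k):
        val_idx = folds[i]
        train_idx = np.concatenate([folds[j] for j in range(k) if j != i])
        yield train_idx, val_idx


def minibatches(indices, batch_size=1024, seed=0):
    """Yield shuffled batches of row indices; the last batch may be smaller."""
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(indices)  # shuffled copy, the input is not modified
    for start in range(0, len(shuffled), batch_size):
        yield shuffled[start:start + batch_size]


# Example with a small made-up dataset of 10 rows and 2 features
X = np.arange(20, dtype='float32').reshape(10, 2)
for fold, (train_idx, val_idx) in enumerate(kfold_indices(len(X), k=5)):
    for batch in minibatches(train_idx, batch_size=3, seed=fold):
        x_batch = X[batch]  # a real training loop would run one optimizer step on this batch
    print(fold, len(train_idx), len(val_idx))  # each line: fold index, 8 training rows, 2 validation rows
```

Averaging the validation accuracy over the K folds gives a steadier estimate than a single run. This matters here because the per-epoch validation accuracy in the log above still swings by up to roughly 0.8 percentage points between consecutive epochs late in training.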
