PaddlePaddle Learning Competition: League of Legends Master Prediction, an 85.365-point solution (February 2023)
Competition link - PaddlePaddle Learning Competition: League of Legends Master Prediction
Problem description
This is a typical binary classification task. Set in the League of Legends mobile game, it asks contestants to predict whether a player wins the current match, based on that player's in-game statistics.
Data description
Each row of the dataset is one player's record for a single game, with the following fields:
- id: record id
- win: whether the player won (the label)
- kills: number of kills
- deaths: number of deaths
- assists: number of assists
- largestkillingspree: largest killing spree (a game term: killing three or more enemy champions without dying in between)
- largestmultikill: largest multikill (a game term: multiple kills within a short time window)
- longesttimespentliving: longest time spent alive
- doublekills: number of double kills
- triplekills: number of triple kills
- quadrakills: number of quadra kills
- pentakills: number of penta kills
- totdmgdealt: total damage dealt
- magicdmgdealt: magic damage dealt
- physicaldmgdealt: physical damage dealt
- truedmgdealt: true damage dealt
- largestcrit: largest critical strike
- totdmgtochamp: total damage dealt to enemy champions
- magicdmgtochamp: magic damage dealt to enemy champions
- physdmgtochamp: physical damage dealt to enemy champions
- truedmgtochamp: true damage dealt to enemy champions
- totheal: total healing done
- totunitshealed: total number of units healed
- dmgtoturrets: damage dealt to turrets
- timecc: time spent crowd-controlling others
- totdmgtaken: total damage taken
- magicdmgtaken: magic damage taken
- physdmgtaken: physical damage taken
- truedmgtaken: true damage taken
- wardsplaced: number of wards placed
- wardskilled: number of wards destroyed
- firstblood: whether the player drew first blood
In the test set the label column win is empty; contestants must predict it.
Difficulty
The main difficulties of this competition lie in data processing and model selection:
- The data has no missing values and no clearly abnormal outliers, and the distributions look unremarkable, so feature engineering offers few easy accuracy gains.
- In fact, even with no preprocessing at all, random forest, LightGBM, and XGBoost models easily reach 82-84 points; a neural network does not improve much on this and can even underperform the tree models.
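To make the baseline claim concrete, here is a minimal sketch of such an untuned tree baseline. It uses scikit-learn's RandomForestClassifier on synthetic stand-in data (the real pipeline would read data/train.csv instead):

```python
# Minimal untuned tree baseline, on synthetic data shaped like a numeric
# feature table with a binary label (stand-in for data/train.csv).
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 10))                  # stand-in numeric features
y = (X[:, 0] + 0.5 * X[:, 1] > 0).astype(int)    # stand-in "win" label

X_tr, X_va, y_tr, y_va = train_test_split(X, y, test_size=0.2, random_state=0)
clf = RandomForestClassifier(n_estimators=100, random_state=0).fit(X_tr, y_tr)
print(round(clf.score(X_va, y_va), 3))           # validation accuracy
```

On the real data, this kind of out-of-the-box model is what lands in the 82-84 range.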
Highlights
- Used a genetic algorithm (symbolic transformation) to construct additional features, then reduced their dimensionality with PCA
- Designed the network structure around the dense-connection idea to guard against vanishing gradients
Code repository
https://github.com/ZhangzrJerry/paddle-lolmp
- Unpack the data
!unzip -d data/ data/data137276/test.csv.zip
!unzip -d data/ data/data137276/train.csv.zip
!pip install -r requirements.txt -i https://pypi.tuna.tsinghua.edu.cn/simple/
!mkdir result pretrain
Archive: data/data137276/test.csv.zip
inflating: data/test.csv
Archive: data/data137276/train.csv.zip
inflating: data/train.csv
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple/
Successfully installed category_encoders-2.6.0 gplearn-0.4.2 joblib-1.2.0 patsy-0.5.3 pydotplus-2.0.2 scikit-learn-1.0.2 statsmodels-0.13.5
import numpy as np
import pandas as pd
x_train = pd.read_csv('data/train.csv').drop('win',axis=1)
y_train = pd.read_csv('data/train.csv')['win']
x_test = pd.read_csv('data/test.csv')
# concatenate train and test so feature engineering is applied to both at once
feature = pd.concat([x_train, x_test])
- Data visualization
import matplotlib.pyplot as plt
import seaborn as sns
# inspect the data
print(feature.describe())
print(feature.info())
# look at each feature's distribution
sns.FacetGrid(pd.melt(feature), col="variable", col_wrap=4, sharex=False, sharey=False).map(sns.distplot, "value")
# look at feature correlations
sns.set_context({"figure.figsize": (8, 8)})
sns.heatmap(data=feature.corr(), square=True, cmap='RdBu_r')
id kills deaths assists \
count 200000.000000 200000.000000 200000.000000 200000.000000
mean 99999.500000 5.798545 5.810190 8.322870
std 57735.171256 4.605316 3.263815 5.933893
min 0.000000 0.000000 0.000000 0.000000
25% 49999.750000 2.000000 3.000000 4.000000
50% 99999.500000 5.000000 6.000000 7.000000
75% 149999.250000 8.000000 8.000000 12.000000
max 199999.000000 39.000000 23.000000 52.000000
largestkillingspree largestmultikill longesttimespentliving \
count 200000.000000 200000.000000 200000.000000
mean 2.671450 1.332095 630.531655
std 2.537784 0.758037 311.568408
min 0.000000 0.000000 0.000000
25% 0.000000 1.000000 433.000000
50% 2.000000 1.000000 590.000000
75% 4.000000 2.000000 792.000000
max 31.000000 5.000000 3038.000000
doublekills triplekills quadrakills ... totunitshealed \
count 200000.000000 200000.000000 200000.000000 ... 200000.000000
mean 0.540265 0.073220 0.010235 ... 2.253135
std 0.924831 0.295887 0.104548 ... 2.481890
min 0.000000 0.000000 0.000000 ... 0.000000
25% 0.000000 0.000000 0.000000 ... 1.000000
50% 0.000000 0.000000 0.000000 ... 1.000000
75% 1.000000 0.000000 0.000000 ... 3.000000
max 11.000000 7.000000 4.000000 ... 98.000000
dmgtoturrets timecc totdmgtaken magicdmgtaken physdmgtaken \
count 200000.000000 200000.0 200000.000000 200000.000000 200000.000000
mean 2138.209810 0.0 23226.733180 8136.551010 14039.533350
std 2934.306106 0.0 11873.669826 5161.055339 7754.110833
min 0.000000 0.0 0.000000 0.000000 0.000000
25% 0.000000 0.0 15264.000000 4521.000000 8627.000000
50% 986.000000 0.0 21531.000000 7246.000000 12803.000000
75% 3222.250000 0.0 29465.250000 10739.000000 18205.000000
max 55083.000000 0.0 118130.000000 71631.000000 73172.000000
truedmgtaken wardsplaced wardskilled firstblood
count 200000.000000 200000.000000 200000.000000 200000.000000
mean 1049.892170 11.508290 1.782860 0.100380
std 1266.146212 7.539761 2.226049 0.300507
min 0.000000 0.000000 0.000000 0.000000
25% 274.000000 7.000000 0.000000 0.000000
50% 656.000000 10.000000 1.000000 0.000000
75% 1352.000000 14.000000 3.000000 0.000000
max 25140.000000 322.000000 48.000000 1.000000
[8 rows x 31 columns]
<class 'pandas.core.frame.DataFrame'>
Int64Index: 200000 entries, 0 to 19999
Data columns (total 31 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 id 200000 non-null int64
1 kills 200000 non-null int64
2 deaths 200000 non-null int64
3 assists 200000 non-null int64
4 largestkillingspree 200000 non-null int64
5 largestmultikill 200000 non-null int64
6 longesttimespentliving 200000 non-null int64
7 doublekills 200000 non-null int64
8 triplekills 200000 non-null int64
9 quadrakills 200000 non-null int64
10 pentakills 200000 non-null int64
11 totdmgdealt 200000 non-null int64
12 magicdmgdealt 200000 non-null int64
13 physicaldmgdealt 200000 non-null int64
14 truedmgdealt 200000 non-null int64
15 largestcrit 200000 non-null int64
16 totdmgtochamp 200000 non-null int64
17 magicdmgtochamp 200000 non-null int64
18 physdmgtochamp 200000 non-null int64
19 truedmgtochamp 200000 non-null int64
20 totheal 200000 non-null int64
21 totunitshealed 200000 non-null int64
22 dmgtoturrets 200000 non-null int64
23 timecc 200000 non-null int64
24 totdmgtaken 200000 non-null int64
25 magicdmgtaken 200000 non-null int64
26 physdmgtaken 200000 non-null int64
27 truedmgtaken 200000 non-null int64
28 wardsplaced 200000 non-null int64
29 wardskilled 200000 non-null int64
30 firstblood 200000 non-null int64
dtypes: int64(31)
memory usage: 48.8 MB
None
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/seaborn/distributions.py:288: UserWarning: Data must have variance to compute a kernel density estimate.
warnings.warn(msg, UserWarning)
<matplotlib.axes._subplots.AxesSubplot at 0x7f6b5a36cf50>
[Figure: per-feature distribution plots (main_files/main_5_3.png)]
- Data preprocessing
from category_encoders import OrdinalEncoder
from sklearn.preprocessing import MinMaxScaler
# ordinal-encode the discrete variables
columns= ['kills','deaths','assists', 'largestkillingspree',
'largestmultikill', 'longesttimespentliving', 'doublekills',
'triplekills', 'quadrakills', 'pentakills','firstblood']
m_feature = OrdinalEncoder(cols=columns).fit(feature).transform(feature)
m_feature.columns=[
'id', 'akills', 'adeaths', 'aassists', 'alargestkillingspree',
'alargestmultikill', 'alongesttimespentliving', 'adoublekills',
'atriplekills', 'aquadrakills', 'apentakills', 'totdmgdealt',
'magicdmgdealt', 'physicaldmgdealt', 'truedmgdealt', 'largestcrit',
'totdmgtochamp', 'magicdmgtochamp', 'physdmgtochamp', 'truedmgtochamp',
'totheal', 'totunitshealed', 'dmgtoturrets', 'timecc', 'totdmgtaken',
'magicdmgtaken', 'physdmgtaken', 'truedmgtaken', 'wardsplaced',
'wardskilled', 'afirstblood'
]
m_feature=pd.concat([m_feature,feature[columns]],axis=1).drop(['timecc','id'],axis=1)
# min-max scale the features
sc = MinMaxScaler().fit(m_feature)
feature = sc.transform(m_feature)
pd.DataFrame(feature).to_csv('pretrain/feature.csv',index=False)
pd.DataFrame(feature[:180000]).to_csv('pretrain/train_feature.csv',index=False)
pd.DataFrame(feature[180000:]).to_csv('pretrain/test_feature.csv',index=False)
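For readers unfamiliar with ordinal encoding, here is a small sketch of the idea using pandas only. Note that category_encoders' OrdinalEncoder differs in detail (it numbers categories from 1, in order of appearance), but the principle is the same:

```python
# Ordinal encoding illustrated with pandas: map each distinct value of a
# discrete column to an integer code, then min-max scale to [0, 1].
import pandas as pd
from sklearn.preprocessing import MinMaxScaler

df = pd.DataFrame({"kills": [0, 3, 3, 7], "deaths": [2, 2, 5, 1]})
# per-column integer codes, one code per distinct value (sorted order here)
codes = df.apply(lambda s: s.astype("category").cat.codes)
scaled = MinMaxScaler().fit_transform(codes)
print(codes["kills"].tolist())     # [0, 1, 1, 2]
print(scaled.min(), scaled.max())  # 0.0 1.0
```

The pipeline above keeps both the encoded and the raw copies of these columns, then scales everything together.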
- Automatic feature mining
from gplearn.genetic import SymbolicTransformer
from IPython.display import Image
import pydotplus
import numpy as np
import pandas as pd
function_set = ['add', 'sub', 'mul', 'div', 'log', 'sqrt', 'abs', 'neg', 'max', 'min']
st = SymbolicTransformer(
generations=20,
population_size=1000,
hall_of_fame=100,
n_components=100,
function_set=function_set,
parsimony_coefficient=0.0005,
max_samples=0.9,
verbose=1,
random_state=0,
n_jobs=3
)
train_feature = np.array(pd.read_csv('pretrain/train_feature.csv'))
test_feature = np.array(pd.read_csv('pretrain/test_feature.csv'))
st.fit(train_feature, np.array(pd.read_csv('data/train.csv')['win']))
# visualize the best evolved program
graph = st._best_programs[0].export_graphviz()
graph = pydotplus.graphviz.graph_from_dot_data(graph)
display(Image(graph.create_png()))
# append the evolved features to the original ones and save
pd.DataFrame(
    np.concatenate([st.transform(train_feature), train_feature], axis=1)
).to_csv('pretrain/gptrain.csv', index=False)
pd.DataFrame(
    np.concatenate([st.transform(test_feature), test_feature], axis=1)
).to_csv('pretrain/gptest.csv', index=False)
| Population Average | Best Individual |
---- ------------------------- ------------------------------------------ ----------
Gen Length Fitness Length Fitness OOB Fitness Time Left
0 12.03 0.0967885 4 0.395269 0.413457 2.82m
1 6.91 0.250165 5 0.512841 0.518938 1.67m
2 4.82 0.322645 7 0.600911 0.603908 1.49m
3 5.58 0.365057 12 0.604385 0.606605 1.34m
4 6.48 0.414963 9 0.615257 0.610385 1.35m
5 8.58 0.468763 9 0.628536 0.620473 1.22m
6 11.07 0.508876 15 0.638086 0.641776 1.18m
7 13.61 0.514324 11 0.640324 0.636264 1.12m
8 14.01 0.540671 39 0.656748 0.657258 1.09m
9 15.80 0.544383 39 0.656818 0.656685 1.02m
10 18.48 0.556065 41 0.656622 0.658435 54.90s
11 19.58 0.569885 50 0.66438 0.666361 49.46s
12 20.93 0.58614 39 0.665209 0.665554 46.77s
13 22.09 0.589301 53 0.666591 0.67299 38.94s
14 23.42 0.597276 59 0.666831 0.671305 33.80s
15 25.60 0.600218 65 0.66804 0.668831 28.90s
16 27.89 0.618124 76 0.667599 0.668118 22.03s
17 31.54 0.6234 62 0.668819 0.660909 15.79s
18 33.64 0.631025 55 0.669813 0.669861 7.87s
19 32.85 0.625123 52 0.672489 0.662052 0.00s
[Figure: expression tree of the best evolved program (main_files/main_9_1.png)]
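The programs gplearn evolves are compositions of the primitives in function_set, e.g. div(kills, add(deaths, assists)). Here is a hand-evaluated sketch of one such candidate feature, using a protected division in the style gplearn applies internally (the exact threshold used below is an assumption for illustration):

```python
# Hand-evaluating one candidate evolved feature, div(kills, deaths), with a
# "protected" division that returns 1.0 when the denominator is near zero.
import numpy as np

def protected_div(a, b):
    # avoid blow-ups at b ≈ 0, in the spirit of gplearn's protected operators
    return np.where(np.abs(b) > 1e-3, np.divide(a, np.where(b == 0, 1, b)), 1.0)

kills = np.array([5.0, 0.0, 12.0])
deaths = np.array([2.0, 0.0, 3.0])
feature = protected_div(kills, deaths)  # a candidate engineered feature
print(feature)  # [2.5 1.  4. ]
```

SymbolicTransformer evolves thousands of such expressions and keeps the n_components ones most correlated with the label.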
- Feature dimensionality reduction
import pandas as pd
from sklearn.decomposition import PCA
traindata = pd.read_csv('pretrain/gptrain.csv')
testdata = pd.read_csv('pretrain/gptest.csv')
transfer = PCA(n_components=64)
pd.DataFrame(
transfer.fit_transform(traindata)
).to_csv('pretrain/pcatrain.csv',index=False)
pd.DataFrame(
transfer.transform(testdata)
).to_csv('pretrain/pcatest.csv',index=False)
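The choice of n_components=64 is a judgment call; one way to sanity-check it is to look at how much of the total variance the retained components keep. A sketch on synthetic correlated data (standing in for gptrain.csv):

```python
# Checking retained explained variance after PCA, to judge n_components.
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
# correlated features: random data pushed through a random linear mixing
X = rng.normal(size=(500, 10)) @ rng.normal(size=(10, 10))
pca = PCA(n_components=4).fit(X)
kept = float(pca.explained_variance_ratio_.sum())  # fraction of variance kept
print(round(kept, 3))
```

If the kept fraction is close to 1.0, the dropped components carried little information.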
- Model construction
from paddle import nn
import paddle
import paddle.nn.functional as F
class DenseBlock(nn.Layer):
    def __init__(self):
        super(DenseBlock, self).__init__()
        self.fc1 = nn.Linear(96, 64)
        self.fc2 = nn.Linear(64, 32)

    def forward(self, input, dense):
        x = self.fc1(input)
        x = self.fc2(x)
        x = F.relu(x)
        # dense connection: concatenate this block's output with the raw input
        x = paddle.concat([x, dense], axis=1)
        return x

class MyNet(nn.Layer):
    def __init__(self):
        super(MyNet, self).__init__()
        self.fc0 = nn.Linear(64, 96)
        self.bl1 = DenseBlock()
        self.bl2 = DenseBlock()
        self.bl3 = DenseBlock()
        self.bl4 = DenseBlock()
        self.bl5 = DenseBlock()
        self.bl6 = DenseBlock()
        self.bl7 = DenseBlock()
        self.bl8 = DenseBlock()
        self.bl9 = DenseBlock()
        self.fc4 = nn.Linear(96, 32)
        self.fc5 = nn.Linear(32, 1)

    def forward(self, input):
        x = self.fc0(input)
        # each block sees the raw 64-d input alongside the previous block's output
        x = self.bl1(x, input)
        x = self.bl2(x, input)
        x = self.bl3(x, input)
        x = self.bl4(x, input)
        x = self.bl5(x, input)
        x = self.bl6(x, input)
        x = self.bl7(x, input)
        x = self.bl8(x, input)
        x = self.bl9(x, input)
        x = self.fc4(x)
        x = self.fc5(x)
        return F.sigmoid(x)
paddle.summary(MyNet(),(180000,64))
W0216 17:52:57.241905 5934 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 8.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0216 17:52:57.244976 5934 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
-----------------------------------------------------------------------------------
Layer (type) Input Shape Output Shape Param #
===================================================================================
Linear-1 [[180000, 64]] [180000, 96] 6,240
Linear-2 [[180000, 96]] [180000, 64] 6,208
Linear-3 [[180000, 64]] [180000, 32] 2,080
DenseBlock-1 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-4 [[180000, 96]] [180000, 64] 6,208
Linear-5 [[180000, 64]] [180000, 32] 2,080
DenseBlock-2 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-6 [[180000, 96]] [180000, 64] 6,208
Linear-7 [[180000, 64]] [180000, 32] 2,080
DenseBlock-3 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-8 [[180000, 96]] [180000, 64] 6,208
Linear-9 [[180000, 64]] [180000, 32] 2,080
DenseBlock-4 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-10 [[180000, 96]] [180000, 64] 6,208
Linear-11 [[180000, 64]] [180000, 32] 2,080
DenseBlock-5 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-12 [[180000, 96]] [180000, 64] 6,208
Linear-13 [[180000, 64]] [180000, 32] 2,080
DenseBlock-6 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-14 [[180000, 96]] [180000, 64] 6,208
Linear-15 [[180000, 64]] [180000, 32] 2,080
DenseBlock-7 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-16 [[180000, 96]] [180000, 64] 6,208
Linear-17 [[180000, 64]] [180000, 32] 2,080
DenseBlock-8 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-18 [[180000, 96]] [180000, 64] 6,208
Linear-19 [[180000, 64]] [180000, 32] 2,080
DenseBlock-9 [[180000, 96], [180000, 64]] [180000, 96] 0
Linear-20 [[180000, 96]] [180000, 32] 3,104
Linear-21 [[180000, 32]] [180000, 1] 33
===================================================================================
Total params: 83,969
Trainable params: 83,969
Non-trainable params: 0
-----------------------------------------------------------------------------------
Input size (MB): 43.95
Forward/backward pass size (MB): 2550.20
Params size (MB): 0.32
Estimated Total Size (MB): 2594.47
-----------------------------------------------------------------------------------
{'total_params': 83969, 'trainable_params': 83969}
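The summary's total can be cross-checked by hand: a Linear(in, out) layer has in*out weights plus out biases.

```python
# Cross-check of the parameter counts reported by paddle.summary above.
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weights + biases

fc0 = linear_params(64, 96)                            # 6240 (Linear-1)
block = linear_params(96, 64) + linear_params(64, 32)  # 6208 + 2080 per DenseBlock
head = linear_params(96, 32) + linear_params(32, 1)    # 3104 + 33
total = fc0 + 9 * block + head
print(total)  # 83969
```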
- Model training
import paddle.nn.functional as F
import paddle
import pandas as pd
import numpy as np
def train_pm(model, optimizer, feature, label, epochs=1):
    # train on GPU 0
    paddle.device.set_device('gpu:0')
    print('start training ')
    model.train()
    feature = paddle.to_tensor(feature)
    label = paddle.reshape(paddle.to_tensor(label), (-1, 1))
    for epoch in range(epochs):
        # forward pass; MyNet's forward already applies sigmoid, so use
        # binary_cross_entropy rather than binary_cross_entropy_with_logits
        # (the *_with_logits variant would apply sigmoid a second time)
        pred = model(feature)
        avg_loss = paddle.mean(F.binary_cross_entropy(pred, label))
        # backward pass, update the weights, clear the gradients
        avg_loss.backward()
        optimizer.step()
        optimizer.clear_grad()
        # evaluate on the training data
        model.eval()
        pred = model(feature)
        loss = F.binary_cross_entropy(pred, label)
        # binary classification: threshold the predicted probability at 0.5 by
        # stacking P(class 0) = 1 - p and P(class 1) = p along axis 1
        pred = paddle.concat([1.0 - pred, pred], axis=1)
        acc = paddle.metric.accuracy(pred, paddle.cast(label, dtype='int64'))
        print("[validation] accuracy/loss: {:.4f}/{:.4f}".format(float(acc), float(paddle.mean(loss))))
        model.train()
        # paddle.save(model.state_dict(), 'model/lolmp{}_{}.pdparams'.format(epoch, float(acc)))
mynet = MyNet()
train_pm(
mynet,
paddle.optimizer.Adam(
parameters=mynet.parameters(),
learning_rate=0.005
),
np.array(
pd.read_csv('pretrain/pcatrain.csv')
).astype('float32'),
np.array(
pd.read_csv('data/train.csv')['win']
).astype('float32'),
400
)
start training
[validation] accuracy/loss: 0.5128/0.6676
[validation] accuracy/loss: 0.8079/0.6187
[validation] accuracy/loss: 0.8114/0.5981
[validation] accuracy/loss: 0.8112/0.5925
[validation] accuracy/loss: 0.8110/0.5919
[validation] accuracy/loss: 0.8105/0.5920
[validation] accuracy/loss: 0.8106/0.5913
[validation] accuracy/loss: 0.7105/0.6138
[validation] accuracy/loss: 0.8112/0.5916
[validation] accuracy/loss: 0.8116/0.5929
[validation] accuracy/loss: 0.8118/0.5949
[validation] accuracy/loss: 0.5492/0.7224
[validation] accuracy/loss: 0.8119/0.5934
[validation] accuracy/loss: 0.8120/0.5932
[validation] accuracy/loss: 0.8122/0.5930
[validation] accuracy/loss: 0.8125/0.5928
[validation] accuracy/loss: 0.8127/0.5923
[validation] accuracy/loss: 0.8131/0.5917
[validation] accuracy/loss: 0.8129/0.5909
[validation] accuracy/loss: 0.8123/0.5899
[validation] accuracy/loss: 0.8070/0.5939
[validation] accuracy/loss: 0.8122/0.5894
[validation] accuracy/loss: 0.8127/0.5894
[validation] accuracy/loss: 0.8132/0.5894
[validation] accuracy/loss: 0.8134/0.5892
[validation] accuracy/loss: 0.8135/0.5890
[validation] accuracy/loss: 0.8138/0.5888
[validation] accuracy/loss: 0.8139/0.5885
[validation] accuracy/loss: 0.8141/0.5881
[validation] accuracy/loss: 0.8141/0.5877
[validation] accuracy/loss: 0.8143/0.5874
[validation] accuracy/loss: 0.8134/0.5871
[validation] accuracy/loss: 0.8123/0.5870
[validation] accuracy/loss: 0.8115/0.5866
[validation] accuracy/loss: 0.8118/0.5860
[validation] accuracy/loss: 0.8130/0.5854
[validation] accuracy/loss: 0.8140/0.5849
[validation] accuracy/loss: 0.8143/0.5844
[validation] accuracy/loss: 0.8133/0.5840
[validation] accuracy/loss: 0.8140/0.5835
[validation] accuracy/loss: 0.8165/0.5831
[validation] accuracy/loss: 0.8169/0.5827
[validation] accuracy/loss: 0.8157/0.5824
[validation] accuracy/loss: 0.8185/0.5822
[validation] accuracy/loss: 0.8189/0.5820
[validation] accuracy/loss: 0.8174/0.5819
[validation] accuracy/loss: 0.8191/0.5817
[validation] accuracy/loss: 0.8198/0.5815
[validation] accuracy/loss: 0.8186/0.5814
[validation] accuracy/loss: 0.8202/0.5812
[validation] accuracy/loss: 0.8200/0.5810
[validation] accuracy/loss: 0.8191/0.5809
[validation] accuracy/loss: 0.8207/0.5807
[validation] accuracy/loss: 0.8198/0.5806
[validation] accuracy/loss: 0.8196/0.5805
[validation] accuracy/loss: 0.8213/0.5804
[validation] accuracy/loss: 0.8198/0.5802
[validation] accuracy/loss: 0.8202/0.5800
[validation] accuracy/loss: 0.8220/0.5799
[validation] accuracy/loss: 0.8204/0.5798
[validation] accuracy/loss: 0.8235/0.5797
[validation] accuracy/loss: 0.8219/0.5794
[validation] accuracy/loss: 0.8241/0.5792
[validation] accuracy/loss: 0.8238/0.5791
[validation] accuracy/loss: 0.8248/0.5789
[validation] accuracy/loss: 0.8250/0.5787
[validation] accuracy/loss: 0.8273/0.5786
[validation] accuracy/loss: 0.8191/0.5793
[validation] accuracy/loss: 0.8311/0.5812
[validation] accuracy/loss: 0.8311/0.5816
[validation] accuracy/loss: 0.8311/0.5794
[validation] accuracy/loss: 0.8211/0.5788
[validation] accuracy/loss: 0.8121/0.5805
[validation] accuracy/loss: 0.8299/0.5782
[validation] accuracy/loss: 0.8339/0.5797
[validation] accuracy/loss: 0.8339/0.5799
[validation] accuracy/loss: 0.8332/0.5787
[validation] accuracy/loss: 0.8291/0.5777
[validation] accuracy/loss: 0.8209/0.5783
[validation] accuracy/loss: 0.8206/0.5784
[validation] accuracy/loss: 0.8290/0.5772
[validation] accuracy/loss: 0.8332/0.5775
[validation] accuracy/loss: 0.8350/0.5779
[validation] accuracy/loss: 0.8349/0.5775
[validation] accuracy/loss: 0.8327/0.5767
[validation] accuracy/loss: 0.8278/0.5768
[validation] accuracy/loss: 0.8258/0.5771
[validation] accuracy/loss: 0.8307/0.5765
[validation] accuracy/loss: 0.8343/0.5764
[validation] accuracy/loss: 0.8363/0.5766
[validation] accuracy/loss: 0.8364/0.5764
[validation] accuracy/loss: 0.8348/0.5759
[validation] accuracy/loss: 0.8310/0.5761
[validation] accuracy/loss: 0.8312/0.5760
[validation] accuracy/loss: 0.8349/0.5757
[validation] accuracy/loss: 0.8370/0.5758
[validation] accuracy/loss: 0.8375/0.5757
[validation] accuracy/loss: 0.8356/0.5755
[validation] accuracy/loss: 0.8333/0.5755
[validation] accuracy/loss: 0.8334/0.5754
[validation] accuracy/loss: 0.8361/0.5752
[validation] accuracy/loss: 0.8378/0.5753
[validation] accuracy/loss: 0.8371/0.5751
[validation] accuracy/loss: 0.8358/0.5750
[validation] accuracy/loss: 0.8352/0.5750
[validation] accuracy/loss: 0.8364/0.5749
[validation] accuracy/loss: 0.8378/0.5749
[validation] accuracy/loss: 0.8372/0.5747
[validation] accuracy/loss: 0.8361/0.5747
[validation] accuracy/loss: 0.8367/0.5746
[validation] accuracy/loss: 0.8385/0.5745
[validation] accuracy/loss: 0.8383/0.5744
[validation] accuracy/loss: 0.8373/0.5744
[validation] accuracy/loss: 0.8381/0.5743
[validation] accuracy/loss: 0.8393/0.5742
[validation] accuracy/loss: 0.8384/0.5741
[validation] accuracy/loss: 0.8378/0.5741
[validation] accuracy/loss: 0.8394/0.5740
[validation] accuracy/loss: 0.8397/0.5739
[validation] accuracy/loss: 0.8386/0.5738
[validation] accuracy/loss: 0.8394/0.5738
[validation] accuracy/loss: 0.8402/0.5737
[validation] accuracy/loss: 0.8392/0.5736
[validation] accuracy/loss: 0.8407/0.5735
[validation] accuracy/loss: 0.8401/0.5734
[validation] accuracy/loss: 0.8407/0.5733
[validation] accuracy/loss: 0.8406/0.5732
[validation] accuracy/loss: 0.8417/0.5732
[validation] accuracy/loss: 0.8388/0.5733
[validation] accuracy/loss: 0.8430/0.5739
[validation] accuracy/loss: 0.8332/0.5744
[validation] accuracy/loss: 0.8351/0.5745
[validation] accuracy/loss: 0.8370/0.5742
[validation] accuracy/loss: 0.8367/0.5739
[validation] accuracy/loss: 0.8384/0.5738
[validation] accuracy/loss: 0.8406/0.5738
[validation] accuracy/loss: 0.8399/0.5734
[validation] accuracy/loss: 0.8383/0.5734
[validation] accuracy/loss: 0.8406/0.5732
[validation] accuracy/loss: 0.8421/0.5731
[validation] accuracy/loss: 0.8410/0.5730
[validation] accuracy/loss: 0.8416/0.5729
[validation] accuracy/loss: 0.8422/0.5728
[validation] accuracy/loss: 0.8413/0.5727
[validation] accuracy/loss: 0.8418/0.5725
[validation] accuracy/loss: 0.8440/0.5725
[validation] accuracy/loss: 0.8431/0.5723
[validation] accuracy/loss: 0.8442/0.5721
[validation] accuracy/loss: 0.8448/0.5720
[validation] accuracy/loss: 0.8431/0.5719
[validation] accuracy/loss: 0.8453/0.5717
[validation] accuracy/loss: 0.8445/0.5716
[validation] accuracy/loss: 0.8457/0.5715
[validation] accuracy/loss: 0.8446/0.5715
[validation] accuracy/loss: 0.8464/0.5714
[validation] accuracy/loss: 0.8419/0.5717
[validation] accuracy/loss: 0.8472/0.5727
[validation] accuracy/loss: 0.8378/0.5728
[validation] accuracy/loss: 0.8369/0.5731
[validation] accuracy/loss: 0.8413/0.5727
[validation] accuracy/loss: 0.8431/0.5720
[validation] accuracy/loss: 0.8409/0.5726
[validation] accuracy/loss: 0.8462/0.5720
[validation] accuracy/loss: 0.8439/0.5717
[validation] accuracy/loss: 0.8411/0.5720
[validation] accuracy/loss: 0.8453/0.5713
[validation] accuracy/loss: 0.8469/0.5718
[validation] accuracy/loss: 0.8435/0.5714
[validation] accuracy/loss: 0.8454/0.5711
[validation] accuracy/loss: 0.8461/0.5711
[validation] accuracy/loss: 0.8448/0.5710
[validation] accuracy/loss: 0.8477/0.5709
[validation] accuracy/loss: 0.8467/0.5708
[validation] accuracy/loss: 0.8470/0.5706
[validation] accuracy/loss: 0.8481/0.5706
[validation] accuracy/loss: 0.8463/0.5704
[validation] accuracy/loss: 0.8480/0.5707
[validation] accuracy/loss: 0.8475/0.5703
[validation] accuracy/loss: 0.8477/0.5705
[validation] accuracy/loss: 0.8490/0.5704
[validation] accuracy/loss: 0.8470/0.5700
[validation] accuracy/loss: 0.8485/0.5706
[validation] accuracy/loss: 0.8462/0.5704
[validation] accuracy/loss: 0.8499/0.5700
[validation] accuracy/loss: 0.8472/0.5703
[validation] accuracy/loss: 0.8495/0.5697
[validation] accuracy/loss: 0.8507/0.5697
[validation] accuracy/loss: 0.8455/0.5702
[validation] accuracy/loss: 0.8503/0.5707
[validation] accuracy/loss: 0.8404/0.5715
[validation] accuracy/loss: 0.8451/0.5710
[validation] accuracy/loss: 0.8486/0.5703
[validation] accuracy/loss: 0.8442/0.5710
[validation] accuracy/loss: 0.8505/0.5698
[validation] accuracy/loss: 0.8469/0.5706
[validation] accuracy/loss: 0.8457/0.5706
[validation] accuracy/loss: 0.8503/0.5694
[validation] accuracy/loss: 0.8501/0.5706
[validation] accuracy/loss: 0.8445/0.5704
[validation] accuracy/loss: 0.8462/0.5708
[validation] accuracy/loss: 0.8478/0.5702
[validation] accuracy/loss: 0.8445/0.5703
[validation] accuracy/loss: 0.8507/0.5702
[validation] accuracy/loss: 0.8500/0.5695
[validation] accuracy/loss: 0.8476/0.5698
[validation] accuracy/loss: 0.8501/0.5694
[validation] accuracy/loss: 0.8504/0.5692
[validation] accuracy/loss: 0.8508/0.5691
[validation] accuracy/loss: 0.8508/0.5691
[validation] accuracy/loss: 0.8518/0.5688
[validation] accuracy/loss: 0.8506/0.5690
[validation] accuracy/loss: 0.8520/0.5690
[validation] accuracy/loss: 0.8479/0.5692
[validation] accuracy/loss: 0.8522/0.5695
[validation] accuracy/loss: 0.8473/0.5694
[validation] accuracy/loss: 0.8529/0.5687
[validation] accuracy/loss: 0.8531/0.5684
[validation] accuracy/loss: 0.8503/0.5686
[validation] accuracy/loss: 0.8525/0.5696
[validation] accuracy/loss: 0.8435/0.5703
[validation] accuracy/loss: 0.8494/0.5695
[validation] accuracy/loss: 0.8518/0.5690
[validation] accuracy/loss: 0.8476/0.5698
[validation] accuracy/loss: 0.8525/0.5687
[validation] accuracy/loss: 0.8507/0.5691
[validation] accuracy/loss: 0.8511/0.5686
[validation] accuracy/loss: 0.8526/0.5692
[validation] accuracy/loss: 0.8469/0.5694
[validation] accuracy/loss: 0.8495/0.5698
[validation] accuracy/loss: 0.8510/0.5686
[validation] accuracy/loss: 0.8505/0.5692
[validation] accuracy/loss: 0.8511/0.5687
[validation] accuracy/loss: 0.8515/0.5691
[validation] accuracy/loss: 0.8517/0.5684
[validation] accuracy/loss: 0.8506/0.5689
[validation] accuracy/loss: 0.8525/0.5686
[validation] accuracy/loss: 0.8504/0.5686
[validation] accuracy/loss: 0.8524/0.5681
[validation] accuracy/loss: 0.8538/0.5682
[validation] accuracy/loss: 0.8504/0.5684
[validation] accuracy/loss: 0.8533/0.5687
[validation] accuracy/loss: 0.8527/0.5678
[validation] accuracy/loss: 0.8536/0.5678
[validation] accuracy/loss: 0.8546/0.5678
[validation] accuracy/loss: 0.8502/0.5684
[validation] accuracy/loss: 0.8545/0.5687
[validation] accuracy/loss: 0.8491/0.5686
[validation] accuracy/loss: 0.8549/0.5677
[validation] accuracy/loss: 0.8542/0.5675
[validation] accuracy/loss: 0.8523/0.5679
[validation] accuracy/loss: 0.8553/0.5677
[validation] accuracy/loss: 0.8522/0.5679
[validation] accuracy/loss: 0.8549/0.5675
[validation] accuracy/loss: 0.8538/0.5677
[validation] accuracy/loss: 0.8539/0.5675
[validation] accuracy/loss: 0.8550/0.5673
[validation] accuracy/loss: 0.8533/0.5677
[validation] accuracy/loss: 0.8536/0.5678
[validation] accuracy/loss: 0.8546/0.5673
[validation] accuracy/loss: 0.8540/0.5680
[validation] accuracy/loss: 0.8503/0.5682
[validation] accuracy/loss: 0.8517/0.5697
[validation] accuracy/loss: 0.8543/0.5676
[validation] accuracy/loss: 0.8467/0.5694
[validation] accuracy/loss: 0.8538/0.5685
[validation] accuracy/loss: 0.8511/0.5689
[validation] accuracy/loss: 0.8483/0.5689
[validation] accuracy/loss: 0.8538/0.5683
[validation] accuracy/loss: 0.8530/0.5675
[validation] accuracy/loss: 0.8540/0.5678
[validation] accuracy/loss: 0.8561/0.5671
[validation] accuracy/loss: 0.8517/0.5678
[validation] accuracy/loss: 0.8551/0.5680
[validation] accuracy/loss: 0.8532/0.5676
[validation] accuracy/loss: 0.8517/0.5679
[validation] accuracy/loss: 0.8550/0.5680
[validation] accuracy/loss: 0.8504/0.5680
[validation] accuracy/loss: 0.8553/0.5677
[validation] accuracy/loss: 0.8554/0.5672
[validation] accuracy/loss: 0.8503/0.5680
[validation] accuracy/loss: 0.8548/0.5685
[validation] accuracy/loss: 0.8532/0.5675
[validation] accuracy/loss: 0.8536/0.5672
[validation] accuracy/loss: 0.8553/0.5681
[validation] accuracy/loss: 0.8511/0.5679
[validation] accuracy/loss: 0.8538/0.5675
[validation] accuracy/loss: 0.8565/0.5673
[validation] accuracy/loss: 0.8500/0.5681
[validation] accuracy/loss: 0.8554/0.5670
[validation] accuracy/loss: 0.8559/0.5675
[validation] accuracy/loss: 0.8516/0.5677
[validation] accuracy/loss: 0.8552/0.5674
[validation] accuracy/loss: 0.8550/0.5675
[validation] accuracy/loss: 0.8515/0.5677
[validation] accuracy/loss: 0.8553/0.5672
[validation] accuracy/loss: 0.8567/0.5666
[validation] accuracy/loss: 0.8540/0.5670
[validation] accuracy/loss: 0.8564/0.5668
[validation] accuracy/loss: 0.8570/0.5663
[validation] accuracy/loss: 0.8572/0.5664
[validation] accuracy/loss: 0.8549/0.5667
[validation] accuracy/loss: 0.8567/0.5663
[validation] accuracy/loss: 0.8568/0.5666
[validation] accuracy/loss: 0.8566/0.5663
[validation] accuracy/loss: 0.8573/0.5662
[validation] accuracy/loss: 0.8569/0.5663
[validation] accuracy/loss: 0.8541/0.5673
[validation] accuracy/loss: 0.8527/0.5676
[validation] accuracy/loss: 0.8570/0.5668
[validation] accuracy/loss: 0.8532/0.5673
[validation] accuracy/loss: 0.8568/0.5662
[validation] accuracy/loss: 0.8566/0.5672
[validation] accuracy/loss: 0.8520/0.5675
[validation] accuracy/loss: 0.8538/0.5679
[validation] accuracy/loss: 0.8557/0.5669
[validation] accuracy/loss: 0.8565/0.5666
[validation] accuracy/loss: 0.8527/0.5682
[validation] accuracy/loss: 0.8558/0.5671
[validation] accuracy/loss: 0.8520/0.5686
[validation] accuracy/loss: 0.8507/0.5681
[validation] accuracy/loss: 0.8507/0.5688
[validation] accuracy/loss: 0.8528/0.5677
[validation] accuracy/loss: 0.8528/0.5677
[validation] accuracy/loss: 0.8557/0.5672
[validation] accuracy/loss: 0.8543/0.5674
[validation] accuracy/loss: 0.8536/0.5672
[validation] accuracy/loss: 0.8563/0.5669
[validation] accuracy/loss: 0.8567/0.5664
[validation] accuracy/loss: 0.8558/0.5667
[validation] accuracy/loss: 0.8570/0.5661
[validation] accuracy/loss: 0.8568/0.5663
[validation] accuracy/loss: 0.8563/0.5664
[validation] accuracy/loss: 0.8569/0.5662
[validation] accuracy/loss: 0.8568/0.5663
[validation] accuracy/loss: 0.8582/0.5658
[validation] accuracy/loss: 0.8569/0.5659
[validation] accuracy/loss: 0.8589/0.5659
[validation] accuracy/loss: 0.8555/0.5662
[validation] accuracy/loss: 0.8587/0.5663
[validation] accuracy/loss: 0.8537/0.5667
[validation] accuracy/loss: 0.8588/0.5659
[validation] accuracy/loss: 0.8568/0.5659
[validation] accuracy/loss: 0.8586/0.5655
[validation] accuracy/loss: 0.8594/0.5657
[validation] accuracy/loss: 0.8516/0.5673
[validation] accuracy/loss: 0.8557/0.5674
[validation] accuracy/loss: 0.8584/0.5661
[validation] accuracy/loss: 0.8503/0.5677
[validation] accuracy/loss: 0.8581/0.5661
[validation] accuracy/loss: 0.8589/0.5658
[validation] accuracy/loss: 0.8538/0.5666
[validation] accuracy/loss: 0.8590/0.5658
[validation] accuracy/loss: 0.8578/0.5659
[validation] accuracy/loss: 0.8573/0.5656
[validation] accuracy/loss: 0.8582/0.5666
[validation] accuracy/loss: 0.8529/0.5669
[validation] accuracy/loss: 0.8561/0.5666
[validation] accuracy/loss: 0.8583/0.5659
[validation] accuracy/loss: 0.8563/0.5660
[validation] accuracy/loss: 0.8580/0.5664
[validation] accuracy/loss: 0.8584/0.5655
[validation] accuracy/loss: 0.8578/0.5658
[validation] accuracy/loss: 0.8573/0.5664
[validation] accuracy/loss: 0.8574/0.5657
[validation] accuracy/loss: 0.8583/0.5660
[validation] accuracy/loss: 0.8571/0.5657
[validation] accuracy/loss: 0.8586/0.5659
[validation] accuracy/loss: 0.8563/0.5663
[validation] accuracy/loss: 0.8588/0.5654
[validation] accuracy/loss: 0.8584/0.5657
[validation] accuracy/loss: 0.8569/0.5656
[validation] accuracy/loss: 0.8586/0.5666
[validation] accuracy/loss: 0.8537/0.5667
[validation] accuracy/loss: 0.8567/0.5662
[validation] accuracy/loss: 0.8584/0.5656
[validation] accuracy/loss: 0.8570/0.5658
[validation] accuracy/loss: 0.8590/0.5655
[validation] accuracy/loss: 0.8581/0.5656
[validation] accuracy/loss: 0.8586/0.5655
[validation] accuracy/loss: 0.8604/0.5651
[validation] accuracy/loss: 0.8559/0.5658
[validation] accuracy/loss: 0.8593/0.5657
[validation] accuracy/loss: 0.8582/0.5654
[validation] accuracy/loss: 0.8598/0.5649
[validation] accuracy/loss: 0.8598/0.5655
[validation] accuracy/loss: 0.8570/0.5655
[validation] accuracy/loss: 0.8599/0.5655
[validation] accuracy/loss: 0.8565/0.5655
[validation] accuracy/loss: 0.8604/0.5651
[validation] accuracy/loss: 0.8595/0.5653
[validation] accuracy/loss: 0.8584/0.5653
[validation] accuracy/loss: 0.8602/0.5651
[validation] accuracy/loss: 0.8564/0.5659
[validation] accuracy/loss: 0.8594/0.5649
[validation] accuracy/loss: 0.8596/0.5654
[validation] accuracy/loss: 0.8540/0.5665
[validation] accuracy/loss: 0.8599/0.5656
[validation] accuracy/loss: 0.8580/0.5659
[validation] accuracy/loss: 0.8548/0.5660
[validation] accuracy/loss: 0.8591/0.5663
- Generate predictions
# Load the PCA-transformed test features, run the trained network,
# and binarize the outputs at a 0.5 threshold
test_features = np.array(pd.read_csv('pretrain/pcatest.csv')).astype('float32')
predictions = mynet(paddle.to_tensor(test_features)).numpy()
pd.DataFrame(
    np.where(predictions > 0.5, 1, 0),
    columns=['win']
).to_csv('submission.csv', index=False)
!zip result/submission.zip submission.csv
!rm submission.csv
adding: submission.csv (deflated 90%)
Project Summary
In fact, there is still plenty of room for improvement:
- The model could adopt the Network-in-Network idea to further refine the network structure
- Although all 180,000 samples fit into a single training batch, tuning the batch size may still yield gains
- The project makes no explicit train/validation split; cross-validation could be tried to reduce overfitting
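For the last point, a minimal cross-validation sketch using scikit-learn's `KFold`. The arrays `X` and `y` are hypothetical stand-ins for the PCA-reduced features and the `win` labels, and the actual Paddle training loop is omitted:

```python
import numpy as np
from sklearn.model_selection import KFold

# Hypothetical stand-ins for the PCA-reduced features and labels
# (names assumed, not from the original project)
X = np.random.rand(200, 10).astype('float32')
y = np.random.randint(0, 2, size=200)

# 5-fold split: each sample appears in the validation set exactly once
kf = KFold(n_splits=5, shuffle=True, random_state=42)
for fold, (train_idx, val_idx) in enumerate(kf.split(X)):
    X_train, X_val = X[train_idx], X[val_idx]
    y_train, y_val = y[train_idx], y[val_idx]
    # train the network on (X_train, y_train) and
    # evaluate on (X_val, y_val) here
    print(f'fold {fold}: train={len(train_idx)} val={len(val_idx)}')
```

Averaging the validation accuracy across folds gives a more stable estimate than the single fixed split used above.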