Data Preparation
The API approach from fastbook's 09_tabular notebook didn't work for me, so I downloaded the data with the Kaggle CLI directly from the Anaconda prompt.
Read the data with pandas:
from fastai.tabular.all import *
import pandas as pd

path = URLs.path('titanic')
train_path = path/'train.csv'
test_path = path/'test.csv'
df = pd.read_csv(train_path, low_memory=False, skipinitialspace=True)
test_df = pd.read_csv(test_path, low_memory=False, skipinitialspace=True)
Building the DataLoaders
dep_var = 'Survived'
cat_names = ['Pclass', 'SibSp', 'Parch', 'Fare', 'Cabin', 'Embarked', 'Sex']
cont_names = ['Age', 'PassengerId']
procs = [Categorify, FillMissing, Normalize]
dls = TabularDataLoaders.from_csv(train_path, path, procs=procs,
                                  cat_names=cat_names, cont_names=cont_names,
                                  y_names=dep_var, y_block=CategoryBlock,
                                  valid_idx=list(range(600, 800)), bs=64)
One thing to watch out for is the target passed as y_names: in this dataset it is dep_var = 'Survived', and it must not appear in cat_names or cont_names. The first time through I simply copy-pasted df.columns, and having the target among the categorical variables triggered a baffling pandas error that took me nearly an hour to trace back to this cause.
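A safer pattern is to derive cat_names and cont_names from a column list with the target filtered out up front, so the mistake can't happen. A minimal stdlib sketch (the hard-coded column list mirrors the Titanic CSV header, and the choice of which columns to drop or treat as continuous is illustrative):

```python
# Column list mirroring the Titanic train.csv header
columns = ['PassengerId', 'Survived', 'Pclass', 'Name', 'Sex', 'Age',
           'SibSp', 'Parch', 'Ticket', 'Fare', 'Cabin', 'Embarked']
dep_var = 'Survived'

# Filter out the target (and any columns not used as features) first,
# then split what remains into continuous / categorical names
features = [c for c in columns if c not in (dep_var, 'Name', 'Ticket')]
cont_names = [c for c in features if c in ('Age', 'PassengerId')]
cat_names = [c for c in features if c not in cont_names]

# The target can no longer leak into either list
assert dep_var not in cat_names and dep_var not in cont_names
```

Because the target is removed before the split, copy-pasting the full column list stays safe.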
Training
learn = tabular_learner(dls, metrics=accuracy)
learn.fit_one_cycle(100)
I just used the default tabular_learner; the results are below:
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 0.687910 | 0.672002 | 0.640000 | 00:01 |
1 | 0.679152 | 0.665265 | 0.640000 | 00:00 |
2 | 0.670670 | 0.658679 | 0.680000 | 00:00 |
3 | 0.663144 | 0.649106 | 0.670000 | 00:00 |
4 | 0.654119 | 0.642387 | 0.670000 | 00:00 |
5 | 0.642554 | 0.643897 | 0.610000 | 00:00 |
6 | 0.627525 | 0.626449 | 0.615000 | 00:00 |
7 | 0.608022 | 0.594794 | 0.680000 | 00:00 |
8 | 0.579566 | 0.574257 | 0.710000 | 00:00 |
9 | 0.540928 | 0.532365 | 0.725000 | 00:00 |
10 | 0.495352 | 0.487120 | 0.755000 | 00:00 |
11 | 0.449479 | 0.445067 | 0.770000 | 00:00 |
12 | 0.407328 | 0.510866 | 0.755000 | 00:00 |
13 | 0.372619 | 0.512942 | 0.775000 | 00:00 |
14 | 0.344380 | 0.554082 | 0.775000 | 00:00 |
15 | 0.316252 | 0.487315 | 0.775000 | 00:00 |
16 | 0.290813 | 0.525596 | 0.755000 | 00:00 |
17 | 0.266976 | 0.606796 | 0.775000 | 00:00 |
18 | 0.248752 | 0.604665 | 0.755000 | 00:00 |
19 | 0.232290 | 0.590609 | 0.750000 | 00:00 |
20 | 0.218391 | 0.578665 | 0.745000 | 00:00 |
21 | 0.203587 | 0.586581 | 0.750000 | 00:00 |
22 | 0.189189 | 0.628336 | 0.705000 | 00:00 |
23 | 0.180911 | 0.689595 | 0.715000 | 00:00 |
24 | 0.170023 | 0.637040 | 0.735000 | 00:00 |
25 | 0.161536 | 0.630965 | 0.765000 | 00:00 |
26 | 0.152605 | 0.591396 | 0.740000 | 00:00 |
27 | 0.146294 | 0.648590 | 0.730000 | 00:00 |
28 | 0.142579 | 0.688091 | 0.730000 | 00:00 |
29 | 0.135967 | 0.781510 | 0.715000 | 00:00 |
30 | 0.127739 | 0.744347 | 0.740000 | 00:00 |
31 | 0.123255 | 0.655497 | 0.745000 | 00:00 |
32 | 0.119241 | 0.715305 | 0.745000 | 00:00 |
33 | 0.114936 | 0.902263 | 0.725000 | 00:00 |
34 | 0.112255 | 0.839763 | 0.720000 | 00:00 |
35 | 0.106522 | 0.742290 | 0.730000 | 00:00 |
36 | 0.102651 | 0.789978 | 0.760000 | 00:00 |
37 | 0.100017 | 0.771653 | 0.730000 | 00:00 |
38 | 0.096346 | 0.885213 | 0.720000 | 00:00 |
39 | 0.094071 | 0.821706 | 0.745000 | 00:00 |
40 | 0.089969 | 0.728642 | 0.750000 | 00:00 |
41 | 0.086300 | 0.764377 | 0.735000 | 00:00 |
42 | 0.080375 | 0.737520 | 0.750000 | 00:00 |
43 | 0.075749 | 0.840410 | 0.735000 | 00:00 |
44 | 0.070874 | 0.877053 | 0.735000 | 00:00 |
45 | 0.070402 | 0.879526 | 0.725000 | 00:00 |
46 | 0.069162 | 0.885928 | 0.730000 | 00:00 |
47 | 0.065940 | 0.951272 | 0.755000 | 00:00 |
48 | 0.064935 | 0.835828 | 0.760000 | 00:00 |
49 | 0.062491 | 0.956926 | 0.705000 | 00:00 |
50 | 0.062818 | 0.995981 | 0.730000 | 00:00 |
51 | 0.060378 | 0.943567 | 0.745000 | 00:00 |
52 | 0.061401 | 1.084070 | 0.710000 | 00:00 |
53 | 0.061499 | 0.970715 | 0.725000 | 00:00 |
54 | 0.061217 | 1.169610 | 0.735000 | 00:00 |
55 | 0.061242 | 0.914683 | 0.740000 | 00:00 |
56 | 0.058386 | 0.955398 | 0.740000 | 00:00 |
57 | 0.055345 | 0.921871 | 0.720000 | 00:00 |
58 | 0.052027 | 0.915062 | 0.735000 | 00:00 |
59 | 0.047930 | 0.933616 | 0.735000 | 00:00 |
60 | 0.043968 | 0.998685 | 0.735000 | 00:00 |
61 | 0.042489 | 0.962408 | 0.745000 | 00:00 |
62 | 0.041475 | 1.049256 | 0.720000 | 00:00 |
63 | 0.040391 | 1.045698 | 0.720000 | 00:00 |
64 | 0.042820 | 1.019700 | 0.735000 | 00:00 |
65 | 0.043163 | 0.971371 | 0.725000 | 00:00 |
66 | 0.039586 | 1.006120 | 0.745000 | 00:00 |
67 | 0.039305 | 1.085823 | 0.705000 | 00:00 |
68 | 0.038155 | 1.167134 | 0.700000 | 00:00 |
69 | 0.035776 | 1.143821 | 0.700000 | 00:00 |
70 | 0.035192 | 1.146205 | 0.725000 | 00:00 |
71 | 0.034136 | 1.175187 | 0.710000 | 00:00 |
72 | 0.032035 | 1.082435 | 0.730000 | 00:00 |
73 | 0.031123 | 1.080388 | 0.725000 | 00:00 |
74 | 0.030651 | 1.147222 | 0.710000 | 00:00 |
75 | 0.030057 | 1.151547 | 0.730000 | 00:00 |
76 | 0.028681 | 1.130530 | 0.730000 | 00:00 |
77 | 0.027378 | 1.152099 | 0.715000 | 00:00 |
78 | 0.028083 | 1.134053 | 0.725000 | 00:00 |
79 | 0.027603 | 1.147027 | 0.725000 | 00:00 |
80 | 0.026748 | 1.164077 | 0.715000 | 00:00 |
81 | 0.027223 | 1.200746 | 0.715000 | 00:00 |
82 | 0.025508 | 1.146767 | 0.715000 | 00:00 |
83 | 0.024893 | 1.130284 | 0.720000 | 00:00 |
84 | 0.025231 | 1.126982 | 0.720000 | 00:00 |
85 | 0.025987 | 1.133423 | 0.720000 | 00:00 |
86 | 0.025825 | 1.143805 | 0.715000 | 00:00 |
87 | 0.025462 | 1.144619 | 0.725000 | 00:00 |
88 | 0.024687 | 1.134826 | 0.720000 | 00:00 |
89 | 0.025501 | 1.134439 | 0.735000 | 00:00 |
90 | 0.024548 | 1.137204 | 0.725000 | 00:00 |
91 | 0.023910 | 1.149608 | 0.730000 | 00:00 |
92 | 0.023950 | 1.150210 | 0.730000 | 00:00 |
93 | 0.024212 | 1.141803 | 0.730000 | 00:00 |
94 | 0.025319 | 1.147405 | 0.725000 | 00:00 |
95 | 0.024329 | 1.148225 | 0.730000 | 00:00 |
96 | 0.023654 | 1.158221 | 0.725000 | 00:00 |
97 | 0.024523 | 1.170079 | 0.715000 | 00:00 |
98 | 0.024479 | 1.165972 | 0.720000 | 00:00 |
99 | 0.023362 | 1.153834 | 0.730000 | 00:00 |
Unfortunately, as you can see, the validation loss ends up far above the training loss: textbook overfitting, and a severe case of it. The next step is to bring in something like L2 regularization to rein it in. Accuracy reached 73%, which is barely acceptable.
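In fastai the penalty would go in as the learner's weight-decay argument (wd); as a reminder of what that penalty actually does, here is a toy, framework-free sketch. The quadratic objective is purely illustrative: minimizing f(w) = (w - 3)^2 + lam * w^2 by gradient descent, where lam = 0 gives the unregularized minimum w = 3 and the closed-form regularized minimizer is w* = 3 / (1 + lam), pulled toward zero.

```python
def fit(lam, lr=0.1, steps=500):
    """Gradient descent on f(w) = (w - 3)^2 + lam * w^2."""
    w = 0.0
    for _ in range(steps):
        grad = 2 * (w - 3) + 2 * lam * w  # data-loss gradient + L2 term
        w -= lr * grad
    return w

print(round(fit(0.0), 4))  # -> 3.0
print(round(fit(0.5), 4))  # -> 2.0, i.e. 3 / (1 + 0.5)
```

In the real model the same shrinkage is applied to every weight, discouraging the large weights that memorize the training set.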
Making Predictions
dl = learn.dls.test_dl(test_df)
preds, _ = learn.get_preds(dl=dl)  # per-class probabilities for the test set
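get_preds yields per-class probabilities, so building a submission file still takes an argmax plus the PassengerId column. A stdlib-only sketch (the helper name and file path are made up; in the real notebook the IDs would come from test_df['PassengerId'] and the probabilities from preds):

```python
import csv

def write_submission(passenger_ids, probs, out_path='submission.csv'):
    """Write the PassengerId,Survived CSV that Kaggle expects."""
    with open(out_path, 'w', newline='') as f:
        writer = csv.writer(f)
        writer.writerow(['PassengerId', 'Survived'])
        for pid, (p_died, p_survived) in zip(passenger_ids, probs):
            writer.writerow([pid, int(p_survived > p_died)])  # argmax of two classes

# Tiny fake batch: two passengers with (p_died, p_survived) pairs
write_submission([892, 893], [(0.8, 0.2), (0.1, 0.9)])
```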
Submitting the Results
The submission scored 0.7 (the maximum is 1.00000).
Improvement Plan
My feeling is that fixing the overfitting leaves plenty of room to improve performance; the target is a score of 0.9. Once the improvements are in, I'll push the code to GitHub and paste the repo link here next time.