简单介绍:
赛题来源是天池大数据的 “商场中精确定位用户所在店铺”。原数据有 114 万条,计算起来非常困难。为了让初学者有一个更好的学习体验,也更加基础,我将数据集缩小之后提供下载(原文附有下载链接,提取密码:ndfd),供大家使用。
import pandas as pd
import xgboost as xgb
from sklearn import preprocessing
# Load the competition data: 500-row samples of the original 1.14M-row set.
train = pd.read_csv('train.csv')
tests = pd.read_csv('test.csv')
train
user_id | shop_id | time_stamp | longitude | latitude | wifi_id1 | wifi_strong1 | con_sta1 | wifi_id2 | wifi_strong2 | ... | con_sta7 | wifi_id8 | wifi_strong8 | con_sta8 | wifi_id9 | wifi_strong9 | con_sta9 | wifi_id10 | wifi_strong10 | con_sta10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | u_376 | s_2871718 | 2017/8/6 21:20 | 122.308291 | 32.088040 | b_6396480 | -67 | False | b_41124514 | -86 | ... | FALSE | b_56326644 | -89.0 | FALSE | b_56328155 | -77.0 | FALSE | b_5857369 | -55.0 | false\n |
1 | u_376 | s_2871718 | 2017/8/6 21:20 | 122.308162 | 32.087970 | b_6396480 | -67 | False | b_56328155 | -73 | ... | FALSE | b_6396479 | -57.0 | FALSE | b_31100514 | -89.0 | FALSE | b_5857369 | -57.0 | false\n |
2 | u_1041 | s_181637 | 2017/8/2 13:10 | 117.365255 | 40.638214 | b_8006367 | -78 | False | b_2485110 | -52 | ... | FALSE | b_8006521 | -74.0 | FALSE | b_35013153 | -56.0 | FALSE | b_37608251 | -84.0 | false\n |
3 | u_1158 | s_609470 | 2017/8/13 12:30 | 121.134451 | 31.197416 | b_26250579 | -73 | False | b_26250580 | -64 | ... | FALSE | b_30424471 | -60.0 | FALSE | b_26250578 | -72.0 | FALSE | b_29510856 | -80.0 | false\n |
4 | u_1654 | s_3816766 | 2017/8/25 19:50 | 122.255867 | 31.351320 | b_39004150 | -66 | False | b_39004148 | -58 | ... | FALSE | b_6805211 | -80.0 | FALSE | b_1845687 | -72.0 | FALSE | b_21685901 | -91.0 | false\n |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
495 | u_83642 | s_398021 | 2017/8/24 20:10 | 121.731091 | 32.602940 | b_40778712 | -36 | True | b_40778713 | -53 | ... | FALSE | b_30772238 | -79.0 | FALSE | b_19291072 | -77.0 | FALSE | b_52688309 | -63.0 | false\n |
496 | u_84447 | s_386382 | 2017/8/3 18:10 | 111.341364 | 31.216452 | b_13303539 | -62 | False | b_47973407 | -70 | ... | FALSE | b_13299121 | -62.0 | FALSE | b_47973408 | -66.0 | FALSE | b_56326651 | -49.0 | false\n |
497 | u_84524 | s_322471 | 2017/8/12 18:20 | 122.596036 | 31.581866 | b_54461743 | -46 | False | b_38143992 | -73 | ... | FALSE | b_54461973 | -45.0 | FALSE | b_2837595 | -58.0 | FALSE | b_35405625 | -73.0 | false\n |
498 | u_84860 | s_390053 | 2017/8/6 21:00 | 121.365752 | 32.316147 | b_7962419 | -51 | False | b_46165431 | -65 | ... | FALSE | b_26725258 | -80.0 | FALSE | b_30465621 | -82.0 | FALSE | b_22564180 | -53.0 | true\n |
499 | u_83642 | s_398021 | 2017/8/24 20:10 | 121.731091 | 32.602940 | b_40778712 | -36 | True | b_40778713 | -53 | ... | FALSE | b_30772238 | -79.0 | FALSE | b_19291072 | -77.0 | FALSE | b_52688309 | -63.0 | false\n |
500 rows × 35 columns
tests
row_id | shop_id | user_id | time_stamp | longitude | latitude | wifi_id1 | wifi_strong1 | con_sta1 | wifi_id2 | ... | con_sta7 | wifi_id8 | wifi_strong8 | con_sta8 | wifi_id9 | wifi_strong9 | con_sta9 | wifi_id10 | wifi_strong10 | con_sta10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 118742 | NaN | u_30097142 | 2017/9/5 13:00 | 122.141011 | 39.818847 | b_34366982 | -82 | False | b_37756289 | ... | FALSE | b_28978909 | -62.0 | FALSE | b_21518966 | -68.0 | FALSE | b_13748229 | -72.0 | false\n |
1 | 118743 | NaN | u_30097803 | 2017/9/6 13:10 | 118.191907 | 32.855858 | b_36722251 | -81 | False | b_10537579 | ... | FALSE | b_21694478 | -80.0 | FALSE | b_44551973 | -72.0 | FALSE | b_21694477 | -85.0 | false\n |
2 | 118744 | NaN | u_30097889 | 2017/9/6 17:40 | 119.192110 | 32.424667 | b_30026291 | -74 | False | b_30026290 | ... | FALSE | b_50235613 | -75.0 | FALSE | b_17955238 | -85.0 | FALSE | b_40924464 | -54.0 | false\n |
3 | 118745 | NaN | u_30098996 | 2017/9/3 12:10 | 120.612201 | 34.055249 | b_33412374 | -77 | False | b_22084893 | ... | FALSE | b_21282193 | -87.0 | FALSE | b_33334040 | -71.0 | FALSE | b_29623262 | -68.0 | false\n |
4 | 118746 | NaN | u_30099170 | 2017/9/2 20:40 | 116.861989 | 40.326858 | b_19882704 | -77 | False | b_2241462 | ... | FALSE | b_585687 | -57.0 | FALSE | b_37967785 | -62.0 | FALSE | b_29284311 | -42.0 | false\n |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
495 | 119237 | NaN | u_30257349 | 2017/9/5 17:50 | 120.745494 | 30.815596 | b_19907372 | -91 | False | b_40767122 | ... | FALSE | b_56692079 | -69.0 | FALSE | b_2069544 | -89.0 | FALSE | b_36484904 | -59.0 | false\n |
496 | 119238 | NaN | u_30257371 | 2017/9/2 16:50 | 120.694463 | 31.953709 | b_39339718 | -46 | False | b_52367573 | ... | FALSE | b_21638417 | -63.0 | FALSE | b_21638416 | -61.0 | FALSE | b_19054839 | -62.0 | false\n |
497 | 119239 | NaN | u_30257834 | 2017/9/6 16:00 | 119.192835 | 32.424525 | b_32449092 | -36 | False | b_28588685 | ... | FALSE | b_6899715 | -69.0 | FALSE | b_31951717 | -67.0 | FALSE | b_40924426 | -75.0 | false\n |
498 | 119240 | NaN | u_30257834 | 2017/9/6 16:00 | 119.192796 | 32.424623 | b_49195203 | -69 | False | b_57271624 | ... | FALSE | b_17365028 | -60.0 | FALSE | b_28588685 | -54.0 | FALSE | b_28870484 | -56.0 | false\n |
499 | 119241 | NaN | u_30258138 | 2017/9/1 17:50 | 114.474138 | 31.080863 | b_4337554 | -85 | False | b_12683769 | ... | FALSE | b_25093262 | -73.0 | FALSE | b_11542050 | -86.0 | FALSE | b_36907324 | -83.0 | false\n |
500 rows × 36 columns
将时间的string转化成python datetime
使用pandas的自带api pd.to_datetime()
# Parse the 'time_stamp' strings into datetime64[ns].
# pd.to_datetime accepts a Series directly -- the original pd.Series(...)
# wrapper re-wrapped an existing Series and was redundant.
train['time_stamp'] = pd.to_datetime(train['time_stamp'])
tests['time_stamp'] = pd.to_datetime(tests['time_stamp'])
train['time_stamp']
0 2017-08-06 21:20:00
1 2017-08-06 21:20:00
2 2017-08-02 13:10:00
3 2017-08-13 12:30:00
4 2017-08-25 19:50:00
...
495 2017-08-24 20:10:00
496 2017-08-03 18:10:00
497 2017-08-12 18:20:00
498 2017-08-06 21:00:00
499 2017-08-24 20:10:00
Name: time_stamp, Length: 500, dtype: datetime64[ns]
print(train.info())
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 500 entries, 0 to 499
Data columns (total 35 columns):
user_id 500 non-null object
shop_id 500 non-null object
time_stamp 500 non-null datetime64[ns]
longitude 500 non-null float64
latitude 500 non-null float64
wifi_id1 500 non-null object
wifi_strong1 500 non-null int64
con_sta1 500 non-null bool
wifi_id2 500 non-null object
wifi_strong2 500 non-null int64
con_sta2 500 non-null object
wifi_id3 499 non-null object
wifi_strong3 499 non-null float64
con_sta3 499 non-null object
wifi_id4 497 non-null object
wifi_strong4 497 non-null float64
con_sta4 497 non-null object
wifi_id5 496 non-null object
wifi_strong5 496 non-null float64
con_sta5 496 non-null object
wifi_id6 495 non-null object
wifi_strong6 495 non-null float64
con_sta6 495 non-null object
wifi_id7 494 non-null object
wifi_strong7 494 non-null float64
con_sta7 494 non-null object
wifi_id8 486 non-null object
wifi_strong8 486 non-null float64
con_sta8 486 non-null object
wifi_id9 478 non-null object
wifi_strong9 478 non-null float64
con_sta9 478 non-null object
wifi_id10 467 non-null object
wifi_strong10 467 non-null float64
con_sta10 467 non-null object
dtypes: bool(1), datetime64[ns](1), float64(10), int64(2), object(21)
memory usage: 133.4+ KB
None
将时间datetime细分为year,month,weekday,time
# Expand the parsed timestamp into coarse calendar features.
# Use the vectorised .dt accessor for BOTH frames: the original mixed
# per-row .apply(lambda ...) on train with .dt on tests -- same values,
# but slower and needlessly inconsistent.
train['Year'] = train['time_stamp'].dt.year
train['Month'] = train['time_stamp'].dt.month
train['weekday'] = train['time_stamp'].dt.dayofweek
train['time'] = train['time_stamp'].dt.time
tests['Year'] = tests['time_stamp'].dt.year
tests['Month'] = tests['time_stamp'].dt.month
tests['weekday'] = tests['time_stamp'].dt.dayofweek
tests['time'] = tests['time_stamp'].dt.time
train['Year']
0 2017
1 2017
2 2017
3 2017
4 2017
...
495 2017
496 2017
497 2017
498 2017
499 2017
Name: Year, Length: 500, dtype: int64
train['Month']
0 8
1 8
2 8
3 8
4 8
..
495 8
496 8
497 8
498 8
499 8
Name: Month, Length: 500, dtype: int64
train['weekday']
0 6
1 6
2 2
3 6
4 4
..
495 3
496 3
497 5
498 6
499 3
Name: weekday, Length: 500, dtype: int64
train['time']
0 21:20:00
1 21:20:00
2 13:10:00
3 12:30:00
4 19:50:00
...
495 20:10:00
496 18:10:00
497 18:20:00
498 21:00:00
499 20:10:00
Name: time, Length: 500, dtype: object
删除 'time_stamp' 列以节约内存
# The raw timestamp now lives on as Year/Month/weekday/time -- drop it.
train = train.drop('time_stamp', axis=1)
# Training rows with missing wifi readings can simply be discarded.
train = train.dropna(axis=0)
tests = tests.drop('time_stamp', axis=1)
# Test rows can NOT be dropped (every row_id needs a prediction), so
# forward-fill their gaps instead. The step-by-step walkthrough omitted
# this, although the complete script at the end relies on it; ffill() is
# the non-deprecated spelling of fillna(method='pad').
tests = tests.ffill()
# fillna 的几种缺失值填充方式:
# pad/ffill:用前一个非缺失值去填充该缺失值
# backfill/bfill:用下一个非缺失值填充该缺失值
# method=None(默认):不按方向填充,而是用 value 参数指定的值去替换缺失值
将类别信息用one_hot编码
# Integer-encode every non-target object column so xgboost can consume it.
# NOTE(review): train and tests are encoded with *separate* LabelEncoders,
# so the same wifi_id maps to different integers in each frame. That keeps
# the tutorial simple, but a shared encoder fitted on the union would be
# needed for these categorical features to transfer meaningfully.
for f in train.columns:
    if train[f].dtype == 'object' and f != 'shop_id':
        print(f)
        lbl = preprocessing.LabelEncoder()
        # astype(str) gives the encoder a uniform type to sort
        # (the 'time' column holds datetime.time objects).
        train[f] = lbl.fit_transform(train[f].astype(str))
user_id
wifi_id1
wifi_id2
con_sta2
wifi_id3
con_sta3
wifi_id4
con_sta4
wifi_id5
con_sta5
wifi_id6
con_sta6
wifi_id7
con_sta7
wifi_id8
con_sta8
wifi_id9
con_sta9
wifi_id10
con_sta10
time
train
user_id | shop_id | longitude | latitude | wifi_id1 | wifi_strong1 | con_sta1 | wifi_id2 | wifi_strong2 | con_sta2 | ... | wifi_id9 | wifi_strong9 | con_sta9 | wifi_id10 | wifi_strong10 | con_sta10 | Year | Month | weekday | time | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 108 | s_2871718 | 122.308291 | 32.088040 | 411 | -67 | False | 272 | -86 | 0 | ... | 385 | -77.0 | 0 | 402 | -55.0 | 0 | 2017 | 8 | 6 | 73 |
1 | 108 | s_2871718 | 122.308162 | 32.087970 | 411 | -67 | False | 374 | -73 | 0 | ... | 195 | -89.0 | 0 | 402 | -57.0 | 0 | 2017 | 8 | 6 | 73 |
2 | 2 | s_181637 | 117.365255 | 40.638214 | 434 | -78 | False | 128 | -52 | 0 | ... | 232 | -56.0 | 0 | 253 | -84.0 | 0 | 2017 | 8 | 2 | 24 |
3 | 4 | s_609470 | 121.134451 | 31.197416 | 143 | -73 | False | 147 | -64 | 0 | ... | 144 | -72.0 | 0 | 176 | -80.0 | 0 | 2017 | 8 | 6 | 20 |
4 | 23 | s_3816766 | 122.255867 | 31.351320 | 259 | -66 | False | 250 | -58 | 0 | ... | 91 | -72.0 | 0 | 99 | -91.0 | 0 | 2017 | 8 | 4 | 64 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
495 | 306 | s_398021 | 121.731091 | 32.602940 | 275 | -36 | True | 268 | -53 | 0 | ... | 98 | -77.0 | 0 | 367 | -63.0 | 0 | 2017 | 8 | 3 | 66 |
496 | 307 | s_386382 | 111.341364 | 31.216452 | 31 | -62 | False | 316 | -70 | 0 | ... | 323 | -66.0 | 0 | 391 | -49.0 | 0 | 2017 | 8 | 3 | 54 |
497 | 308 | s_322471 | 122.596036 | 31.581866 | 382 | -46 | False | 248 | -73 | 0 | ... | 171 | -58.0 | 0 | 241 | -73.0 | 0 | 2017 | 8 | 5 | 55 |
498 | 309 | s_390053 | 121.365752 | 32.316147 | 433 | -51 | False | 298 | -65 | 0 | ... | 188 | -82.0 | 0 | 111 | -53.0 | 1 | 2017 | 8 | 6 | 71 |
499 | 306 | s_398021 | 121.731091 | 32.602940 | 275 | -36 | True | 268 | -53 | 0 | ... | 98 | -77.0 | 0 | 367 | -63.0 | 0 | 2017 | 8 | 3 | 66 |
467 rows × 38 columns
对测试数据集应用同样的方法
# Same integer encoding for the test set. fit_transform replaces the
# original separate fit + transform pair; astype(str) makes the encoder
# robust to any remaining NaN in object columns (mixed float/str values
# would otherwise raise during the encoder's internal sort).
for f in tests.columns:
    if tests[f].dtype == 'object':
        print(f)
        lbl = preprocessing.LabelEncoder()
        tests[f] = lbl.fit_transform(tests[f].astype(str))
选取需要的特征
# Feature matrix: calendar features, location, and the ten reported
# wifi access points (id, signal strength, connection flag for each).
feature_columns_to_use = ['Year', 'Month', 'weekday',
                          'time', 'longitude', 'latitude']
for k in range(1, 11):
    feature_columns_to_use += ['wifi_id%d' % k,
                               'wifi_strong%d' % k,
                               'con_sta%d' % k]
big_train = train[feature_columns_to_use]
big_test = tests[feature_columns_to_use]
train_X = big_train.to_numpy()
test_X = big_test.to_numpy()
big_train
Year | Month | weekday | time | longitude | latitude | wifi_id1 | wifi_strong1 | con_sta1 | wifi_id2 | ... | con_sta7 | wifi_id8 | wifi_strong8 | con_sta8 | wifi_id9 | wifi_strong9 | con_sta9 | wifi_id10 | wifi_strong10 | con_sta10 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2017 | 8 | 6 | 73 | 122.308291 | 32.088040 | 411 | -67 | False | 272 | ... | 0 | 386 | -89.0 | 0 | 385 | -77.0 | 0 | 402 | -55.0 | 0 |
1 | 2017 | 8 | 6 | 73 | 122.308162 | 32.087970 | 411 | -67 | False | 374 | ... | 0 | 405 | -57.0 | 0 | 195 | -89.0 | 0 | 402 | -57.0 | 0 |
2 | 2017 | 8 | 2 | 24 | 117.365255 | 40.638214 | 434 | -78 | False | 128 | ... | 0 | 433 | -74.0 | 0 | 232 | -56.0 | 0 | 253 | -84.0 | 0 |
3 | 2017 | 8 | 6 | 20 | 121.134451 | 31.197416 | 143 | -73 | False | 147 | ... | 0 | 184 | -60.0 | 0 | 144 | -72.0 | 0 | 176 | -80.0 | 0 |
4 | 2017 | 8 | 4 | 64 | 122.255867 | 31.351320 | 259 | -66 | False | 250 | ... | 0 | 408 | -80.0 | 0 | 91 | -72.0 | 0 | 99 | -91.0 | 0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
495 | 2017 | 8 | 3 | 66 | 121.731091 | 32.602940 | 275 | -36 | True | 268 | ... | 0 | 189 | -79.0 | 0 | 98 | -77.0 | 0 | 367 | -63.0 | 0 |
496 | 2017 | 8 | 3 | 54 | 111.341364 | 31.216452 | 31 | -62 | False | 316 | ... | 0 | 26 | -62.0 | 0 | 323 | -66.0 | 0 | 391 | -49.0 | 0 |
497 | 2017 | 8 | 5 | 55 | 122.596036 | 31.581866 | 382 | -46 | False | 248 | ... | 0 | 366 | -45.0 | 0 | 171 | -58.0 | 0 | 241 | -73.0 | 0 |
498 | 2017 | 8 | 6 | 71 | 121.365752 | 32.316147 | 433 | -51 | False | 298 | ... | 0 | 151 | -80.0 | 0 | 188 | -82.0 | 0 | 111 | -53.0 | 1 |
499 | 2017 | 8 | 3 | 66 | 121.731091 | 32.602940 | 275 | -36 | True | 268 | ... | 0 | 189 | -79.0 | 0 | 98 | -77.0 | 0 | 367 | -63.0 | 0 |
467 rows × 36 columns
train_X[0]
array([2017, 8, 6, 73, 122.308291, 32.08804, 411, -67, False, 272, -86, 0,
160, -90.0, 0, 403, -55.0, 0, 446, -90.0, 0, 208, -74.0, 0, 405,
-68.0, 0, 386, -89.0, 0, 385, -77.0, 0, 402, -55.0, 0],
dtype=object)
train_y = train['shop_id']
# Modern xgboost (>=1.3) requires numeric class labels and replaced the
# `silent` constructor argument with `verbosity`. Encode the shop ids for
# fitting and map the model output back, so `predictions` still holds the
# original shop_id strings for the submission step below.
target_enc = preprocessing.LabelEncoder()
train_y_enc = target_enc.fit_transform(train_y)
gbm = xgb.XGBClassifier(verbosity=0, max_depth=10,
                        n_estimators=1000, learning_rate=0.05)
gbm.fit(train_X, train_y_enc)
predictions = target_enc.inverse_transform(gbm.predict(test_X))
提交预测
# Assemble and persist the submission: one predicted shop_id per test row_id.
submission = pd.DataFrame({'row_id': tests['row_id'],
                           'shop_id': predictions})
print(submission)
submission.to_csv('submission.csv', index=False)
完整的代码是这样。
# ---------------------------------------------------------------------------
# Complete pipeline: load the mall-localization data, engineer datetime
# features, integer-encode categoricals, train an xgboost classifier on the
# wifi fingerprint, and write submission.csv (row_id -> predicted shop_id).
# ---------------------------------------------------------------------------
import pandas as pd
import xgboost as xgb
from sklearn import preprocessing

train = pd.read_csv('train.csv')
tests = pd.read_csv('test.csv')

# Parse timestamps (to_datetime accepts a Series directly; the original
# pd.Series(...) wrapper was redundant).
train['time_stamp'] = pd.to_datetime(train['time_stamp'])
tests['time_stamp'] = pd.to_datetime(tests['time_stamp'])
print(train.info())

# Expand the datetime into coarse calendar features, using the vectorised
# .dt accessor consistently for both frames (the original mixed .apply
# lambdas with .dt).
for df in (train, tests):
    df['Year'] = df['time_stamp'].dt.year
    df['Month'] = df['time_stamp'].dt.month
    df['weekday'] = df['time_stamp'].dt.dayofweek
    df['time'] = df['time_stamp'].dt.time

# The raw timestamp is no longer needed.
train = train.drop('time_stamp', axis=1)
tests = tests.drop('time_stamp', axis=1)

# Training rows with missing wifi readings are dropped; test rows must all
# be predicted, so forward-fill their gaps instead.
# (fillna(method='pad') is deprecated -- ffill() is the modern spelling.)
train = train.dropna(axis=0)
tests = tests.ffill()

# Integer-encode categorical columns.
# NOTE(review): train and tests get *separate* encoders, so the same
# wifi_id maps to different integers in each frame; a shared encoder
# fitted on the union would be needed for these features to transfer
# meaningfully between train and test.
for f in train.columns:
    if train[f].dtype == 'object' and f != 'shop_id':
        print(f)
        lbl = preprocessing.LabelEncoder()
        train[f] = lbl.fit_transform(train[f].astype(str))
for f in tests.columns:
    if tests[f].dtype == 'object':
        print(f)
        lbl = preprocessing.LabelEncoder()
        tests[f] = lbl.fit_transform(tests[f].astype(str))

# Calendar + location + the ten reported wifi access points.
feature_columns_to_use = ['Year', 'Month', 'weekday',
                          'time', 'longitude', 'latitude']
for k in range(1, 11):
    feature_columns_to_use += ['wifi_id%d' % k,
                               'wifi_strong%d' % k,
                               'con_sta%d' % k]
train_X = train[feature_columns_to_use].to_numpy()
test_X = tests[feature_columns_to_use].to_numpy()

# Modern xgboost (>=1.3) requires numeric class labels and replaced the
# `silent` argument with `verbosity`; encode shop_id for fitting and map
# the predictions back to the original shop_id strings.
target_enc = preprocessing.LabelEncoder()
train_y = target_enc.fit_transform(train['shop_id'])
gbm = xgb.XGBClassifier(verbosity=0, max_depth=10,
                        n_estimators=1000, learning_rate=0.05)
gbm.fit(train_X, train_y)
predictions = target_enc.inverse_transform(gbm.predict(test_X))

submission = pd.DataFrame({'row_id': tests['row_id'],
                           'shop_id': predictions})
print(submission)
submission.to_csv("submission.csv", index=False)