这一节学习机器学习算法的对比:本章使用九种不同的算法对预测指数涨跌的效果进行对比。
# Build the per-trading-day feature table `dt` for the index:
#   money rate %     -- mean net money-flow rate across the constituents
#   net up rate %    -- (risers - fallers) / number of constituents (breadth)
#   mean of updown % -- mean daily quote change of the constituents
# then append the index's own return and tomorrow's up/down label.
dt = pd.DataFrame(columns=label)
for date in tradelist:
    # The constituent list changes over time, so re-query it per date.
    stock = get_index_stocks(indexcode, date)
    df = get_price(stock, date, date, '1d', ['quote_rate'],
                   skip_paused=False, fq='pre', bar_count=0,
                   is_panel=1)['quote_rate'].T.fillna(0)
    # Mean constituent return on this day (rounded to 3 decimals).
    label3 = round(df.mean()[0], 3)
    # Market breadth: risers minus fallers, as a fraction of all
    # constituents (flat stocks count in the denominator only).
    label2 = (len(list(df[df[date] > 0][date]))
              - len(list(df[df[date] < 0][date]))) / len(list(df[date]))
    moneydf = get_money_flow_step(stock, date, date, '1d',
                                  ['net_flow_rate'], None,
                                  is_panel=1)['net_flow_rate'].T.fillna(0)
    # Mean net money-flow rate of the constituents.
    label1 = round(moneydf.mean()[0], 3)
    dt.loc[date] = [label1, label2, label3]
# The index's own daily return over the whole sample window.
value = list(get_price(indexcode, startdate, enddate, '1d', ['quote_rate'],
                       skip_paused=False, fq='pre', bar_count=0,
                       is_panel=1)['quote_rate'])
dt['now up'] = value
dt['now label'] = dt['now up'].apply(lambda x: 1 if x > 0 else -1)
# Shift the return series back one day so each row sees tomorrow's move;
# the final row has no next day and is padded with 0 (labelled -1 below).
dt['next up'] = list(dt['now up'])[1:] + [0]
dt['label'] = dt['next up'].apply(lambda x: 1 if x > 0 else -1)
dt
Out[208]:
money rate % | net up rate % | mean of updown % | now up | now label | next up | label | |
---|---|---|---|---|---|---|---|
20140102 | -6.977 | -0.44 | -0.455 | -0.8688 | -1 | -1.6035 | -1 |
20140103 | -14.628 | -0.58 | -1.197 | -1.6035 | -1 | -1.5062 | -1 |
20140106 | -10.977 | -0.62 | -2.111 | -1.5062 | -1 | -0.1504 | -1 |
20140107 | -2.380 | -0.32 | 0.090 | -0.1504 | -1 | 0.3097 | 1 |
20140108 | -4.171 | -0.04 | -0.175 | 0.3097 | 1 | -0.7557 | -1 |
20140109 | -4.399 | -0.62 | -0.917 | -0.7557 | -1 | -0.3732 | -1 |
20140110 | -8.866 | -0.22 | -0.622 | -0.3732 | -1 | -0.2439 | -1 |
20140113 | -11.486 | -0.02 | -0.250 | -0.2439 | -1 | 0.3126 | 1 |
20140114 | -7.618 | 0.32 | 0.199 | 0.3126 | 1 | -0.7187 | -1 |
20140115 | -12.046 | -0.52 | -0.620 | -0.7187 | -1 | 0.3080 | 1 |
20140116 | -0.555 | 0.04 | 0.262 | 0.3080 | 1 | -1.1669 | -1 |
20140117 | -11.065 | -0.64 | -1.425 | -1.1669 | -1 | -0.7338 | -1 |
20140120 | -10.974 | -0.64 | -0.878 | -0.7338 | -1 | 0.9586 | 1 |
20140121 | -3.121 | 0.80 | 0.927 | 0.9586 | 1 | 2.3827 | 1 |
20140122 | 10.358 | 1.00 | 2.710 | 2.3827 | 1 | -0.9966 | -1 |
20140123 | -10.280 | -0.72 | -0.755 | -0.9966 | -1 | 0.2024 | 1 |
20140124 | 0.982 | 0.48 | 0.592 | 0.2024 | 1 | -1.5445 | -1 |
20140127 | -12.541 | -0.72 | -1.411 | -1.5445 | -1 | 0.6044 | 1 |
20140128 | -5.032 | 0.34 | 0.366 | 0.6044 | 1 | 0.4495 | 1 |
20140129 | -7.350 | 0.08 | 0.095 | 0.4495 | 1 | -1.1472 | -1 |
20140130 | -13.810 | -0.76 | -1.215 | -1.1472 | -1 | -0.1355 | -1 |
20140207 | -9.088 | -0.16 | 0.115 | -0.1355 | -1 | 2.0667 | 1 |
20140210 | 7.470 | 0.96 | 2.693 | 2.0667 | 1 | 1.4934 | 1 |
20140211 | 5.662 | 0.72 | 1.398 | 1.4934 | 1 | -0.0864 | -1 |
20140212 | -5.996 | -0.24 | -0.027 | -0.0864 | -1 | 0.0664 | 1 |
20140213 | -4.087 | -0.30 | -0.516 | 0.0664 | 1 | 0.4073 | 1 |
20140214 | -7.007 | 0.36 | 0.796 | 0.4073 | 1 | 0.2520 | 1 |
20140217 | -8.357 | 0.38 | 0.597 | 0.2520 | 1 | -1.9221 | -1 |
20140218 | -22.680 | -0.84 | -1.890 | -1.9221 | -1 | 1.6476 | 1 |
20140219 | 1.759 | 0.78 | 1.418 | 1.6476 | 1 | -0.8646 | -1 |
... | ... | ... | ... | ... | ... | ... | ... |
20181211 | -5.442 | 0.30 | 0.410 | 0.2894 | 1 | 0.3204 | 1 |
20181212 | -2.328 | 0.60 | 0.376 | 0.3204 | 1 | 1.4272 | 1 |
20181213 | 4.618 | 0.88 | 1.545 | 1.4272 | 1 | -1.4046 | -1 |
20181214 | -9.521 | -0.94 | -1.470 | -1.4046 | -1 | 0.1058 | 1 |
20181217 | -5.410 | 0.46 | 0.399 | 0.1058 | 1 | -1.1574 | -1 |
20181218 | -9.453 | -0.72 | -1.188 | -1.1574 | -1 | -1.1810 | -1 |
20181219 | -12.110 | -0.70 | -1.012 | -1.1810 | -1 | -1.4613 | -1 |
20181220 | -8.132 | -0.64 | -1.128 | -1.4613 | -1 | -1.2324 | -1 |
20181221 | -7.126 | -0.60 | -1.204 | -1.2324 | -1 | 0.1339 | 1 |
20181224 | 0.890 | 0.16 | 0.319 | 0.1339 | 1 | -0.5153 | -1 |
20181225 | -1.902 | -0.46 | -0.801 | -0.5153 | -1 | -0.6867 | -1 |
20181226 | -2.723 | -0.44 | -0.569 | -0.6867 | -1 | -0.2411 | -1 |
20181227 | -4.277 | 0.02 | -0.454 | -0.2411 | -1 | 0.7487 | 1 |
20181228 | 1.770 | 0.48 | 0.564 | 0.7487 | 1 | -1.3217 | -1 |
20190102 | -5.086 | -0.64 | -1.368 | -1.3217 | -1 | 0.2851 | 1 |
20190103 | 0.290 | 0.18 | 0.151 | 0.2851 | 1 | 2.0008 | 1 |
20190104 | 7.130 | 0.90 | 2.067 | 2.0008 | 1 | -0.0141 | -1 |
20190107 | -4.997 | 0.08 | 0.212 | -0.0141 | -1 | -0.3953 | -1 |
20190108 | -2.920 | -0.34 | -0.415 | -0.3953 | -1 | 1.1951 | 1 |
20190109 | 5.368 | 0.70 | 1.068 | 1.1951 | 1 | -0.0372 | -1 |
20190110 | -4.994 | -0.48 | -0.376 | -0.0372 | -1 | 0.9712 | 1 |
20190111 | 0.116 | 0.74 | 0.777 | 0.9712 | 1 | -0.9923 | -1 |
20190114 | -6.029 | -0.62 | -0.862 | -0.9923 | -1 | 2.0262 | 1 |
20190115 | 6.696 | 0.90 | 1.631 | 2.0262 | 1 | 0.1198 | 1 |
20190116 | -3.404 | 0.16 | -0.150 | 0.1198 | 1 | -0.4145 | -1 |
20190117 | -2.421 | -0.38 | -0.432 | -0.4145 | -1 | 1.9405 | 1 |
20190118 | 4.543 | 0.86 | 1.558 | 1.9405 | 1 | 0.6256 | 1 |
20190121 | -0.123 | 0.32 | 0.640 | 0.6256 | 1 | -1.2849 | -1 |
20190122 | -10.820 | -0.82 | -1.257 | -1.2849 | -1 | -0.1742 | -1 |
20190123 | -4.998 | -0.02 | -0.088 | -0.1742 | -1 | 0.0000 | -1 |
1236 rows × 7 columns
KNN
In [256]:
# Feature column names -- must match the columns built into `dt` above
# (note the trailing space in 'net up rate % ' is intentional: it matches
# the column name created earlier).
label = ['money rate %', 'net up rate % ', 'mean of updown %']
label1 = 'money rate %'
label2 = 'net up rate % '
label3 = 'mean of updown %'
# 保留近一年的数据,用于测试,之前数据用于训练
# Hold out the most recent 250 trading days (~1 year) for testing;
# everything earlier is the training set.
train = dt[:-250]
test = dt[-250:].copy()  # .copy(): later column writes must not alias `dt`
X = train[label]
Y = train['label']
X_test = test[label]
Y_test = test['label']

from sklearn.neighbors import KNeighborsClassifier
from sklearn.linear_model import LogisticRegression
from sklearn import tree
from sklearn import svm
from sklearn.naive_bayes import GaussianNB
from sklearn.cluster import KMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.ensemble import GradientBoostingClassifier
from xgboost import XGBClassifier

# One (name, factory) pair per model, in the original comparison order.
# NOTE(review): KMeans is an unsupervised clusterer -- fit() ignores Y and
# predict() returns cluster ids 0..7, not the +-1 labels, which is why its
# reported "accuracy" is ~0.1.  Kept as-is to preserve the published results.
model_specs = [
    ('tree', tree.DecisionTreeClassifier),
    ('KNN', lambda: KNeighborsClassifier(n_neighbors=30)),
    ('SVM', svm.SVC),
    ('XGBClass', XGBClassifier),
    ('GBM', GradientBoostingClassifier),
    ('Random Forest', RandomForestClassifier),
    ('KMeans', KMeans),
    ('GaussianNB', GaussianNB),
    ('Logistic', LogisticRegression),
]

dr = pd.DataFrame()
for s, make_model in model_specs:
    model = make_model()
    model.fit(X, Y)
    print('训练时,预测成功率 {}'.format(round(np.mean(model.predict(X) == Y), 2)))
    print('测试时,预测成功率 {}'.format(round(np.mean(model.predict(X_test) == Y_test), 2)))
    name = str(s) + ' net value'
    # 净值 -- strategy net value: hold the index only on days the model
    # predicts "up" (Forecast == 1); otherwise the daily factor is 1.0.
    test['Forecast'] = list(model.predict(X_test))
    test['ref'] = test['next up'].loc[test['Forecast'] == 1]
    test = test.fillna(0)
    test['ref'] = test['ref'].apply(lambda x: 1 + x / 100)
    test['date'] = test.index
    # Running product of the daily factors = net-value curve.  cumprod() is
    # the O(n) equivalent of the original reduce(mul, prefix) per row (O(n^2)).
    test[name] = test['ref'].cumprod()
    dr[name] = test[name]
dr
训练时,预测成功率 1.0 测试时,预测成功率 0.53 训练时,预测成功率 0.56 测试时,预测成功率 0.55 训练时,预测成功率 0.61 测试时,预测成功率 0.58 训练时,预测成功率 0.72 测试时,预测成功率 0.5 训练时,预测成功率 0.76 测试时,预测成功率 0.51 训练时,预测成功率 0.97 测试时,预测成功率 0.55 训练时,预测成功率 0.1 测试时,预测成功率 0.09 训练时,预测成功率 0.51 测试时,预测成功率 0.5 训练时,预测成功率 0.53 测试时,预测成功率 0.52
Out[256]:
tree net value | KNN net value | SVM net value | XGBClass net value | GBM net value | Random Forest net value | KMeans net value | GaussianNB net value | Logistic net value | |
---|---|---|---|---|---|---|---|---|---|
20180115 | 1.004937 | 1.000000 | 1.004937 | 1.000000 | 1.000000 | 1.000000 | 1.000000 | 1.004937 | 1.004937 |
20180116 | 1.008978 | 1.004021 | 1.008978 | 1.004021 | 1.004021 | 1.004021 | 1.000000 | 1.004937 | 1.008978 |
20180117 | 1.018939 | 1.004021 | 1.018939 | 1.013934 | 1.013934 | 1.013934 | 1.009873 | 1.004937 | 1.008978 |
20180118 | 1.022635 | 1.004021 | 1.018939 | 1.017611 | 1.017611 | 1.013934 | 1.009873 | 1.004937 | 1.012637 |
20180119 | 1.026738 | 1.004021 | 1.018939 | 1.021694 | 1.021694 | 1.013934 | 1.009873 | 1.004937 | 1.016700 |
20180122 | 1.041883 | 1.018831 | 1.018939 | 1.036765 | 1.036765 | 1.028890 | 1.009873 | 1.019761 | 1.031697 |
20180123 | 1.043139 | 1.018831 | 1.018939 | 1.036765 | 1.036765 | 1.028890 | 1.009873 | 1.019761 | 1.031697 |
20180124 | 1.035637 | 1.011504 | 1.011611 | 1.029308 | 1.029308 | 1.021490 | 1.009873 | 1.019761 | 1.024277 |
20180125 | 1.040662 | 1.011504 | 1.016520 | 1.029308 | 1.029308 | 1.021490 | 1.009873 | 1.024709 | 1.029247 |
20180126 | 1.023578 | 0.994899 | 1.016520 | 1.012411 | 1.012411 | 1.021490 | 1.009873 | 1.024709 | 1.029247 |
20180129 | 1.009941 | 0.981644 | 1.016520 | 0.998923 | 0.998923 | 1.007881 | 1.009873 | 1.011057 | 1.015535 |
20180130 | 1.022312 | 0.981644 | 1.028971 | 1.011159 | 1.011159 | 1.007881 | 1.009873 | 1.023441 | 1.027974 |
20180131 | 1.030313 | 0.989327 | 1.037025 | 1.011159 | 1.011159 | 1.015770 | 1.009873 | 1.023441 | 1.036020 |
20180201 | 1.033072 | 0.991977 | 1.039802 | 1.013867 | 1.013867 | 1.018490 | 1.012577 | 1.023441 | 1.036020 |
20180202 | 1.033072 | 1.001960 | 1.050266 | 1.024070 | 1.013867 | 1.018490 | 1.012577 | 1.033741 | 1.046446 |
20180205 | 1.033072 | 0.981851 | 1.029188 | 1.003517 | 0.993518 | 0.998049 | 0.992255 | 1.033741 | 1.025444 |
20180206 | 1.003947 | 0.954169 | 1.000172 | 0.975225 | 0.965508 | 0.969911 | 0.964280 | 1.004597 | 0.996534 |
20180207 | 1.003947 | 0.954169 | 1.000172 | 0.948264 | 0.938816 | 0.969911 | 0.964280 | 0.976824 | 0.968984 |
20180208 | 0.957660 | 0.954169 | 1.000172 | 0.904544 | 0.895532 | 0.969911 | 0.964280 | 0.931787 | 0.924309 |
20180209 | 0.956019 | 0.954169 | 1.000172 | 0.904544 | 0.895532 | 0.969911 | 0.964280 | 0.930190 | 0.922724 |
20180212 | 0.972209 | 0.970328 | 1.000172 | 0.904544 | 0.910697 | 0.986336 | 0.964280 | 0.945943 | 0.938351 |
20180213 | 0.978708 | 0.976815 | 1.006858 | 0.904544 | 0.910697 | 0.992930 | 0.970727 | 0.945943 | 0.938351 |
20180214 | 0.978708 | 0.976815 | 1.006858 | 0.904544 | 0.910697 | 0.992930 | 0.970727 | 0.945943 | 0.957987 |
20180222 | 0.987150 | 0.976815 | 1.015543 | 0.912347 | 0.918553 | 1.001495 | 0.979100 | 0.945943 | 0.957987 |
20180223 | 0.992697 | 0.976815 | 1.021249 | 0.912347 | 0.918553 | 1.007122 | 0.984602 | 0.945943 | 0.957987 |
20180226 | 0.992697 | 0.976815 | 1.021249 | 0.912347 | 0.918553 | 1.007122 | 0.984602 | 0.945943 | 0.957987 |
20180227 | 0.992697 | 0.976815 | 1.004353 | 0.897252 | 0.903356 | 0.990460 | 0.984602 | 0.930292 | 0.942137 |
20180228 | 0.997053 | 0.981101 | 1.008760 | 0.901189 | 0.907320 | 0.994806 | 0.984602 | 0.934374 | 0.946271 |
20180301 | 0.997053 | 0.981101 | 1.008760 | 0.901189 | 0.907320 | 0.985476 | 0.984602 | 0.925612 | 0.937397 |
20180302 | 0.996986 | 0.981101 | 1.008760 | 0.901129 | 0.907259 | 0.985476 | 0.984602 | 0.925550 | 0.937334 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
20181211 | 0.965227 | 1.098993 | 1.191051 | 0.903900 | 0.912769 | 0.986589 | 0.987298 | 0.882477 | 0.913479 |
20181212 | 0.965227 | 1.114678 | 1.208050 | 0.903900 | 0.925796 | 1.000669 | 0.987298 | 0.882477 | 0.913479 |
20181213 | 0.965227 | 1.114678 | 1.208050 | 0.891203 | 0.925796 | 1.000669 | 0.973431 | 0.882477 | 0.913479 |
20181214 | 0.965227 | 1.115857 | 1.209328 | 0.892146 | 0.926776 | 1.001728 | 0.973431 | 0.883411 | 0.914446 |
20181217 | 0.965227 | 1.115857 | 1.209328 | 0.881821 | 0.916049 | 0.990134 | 0.973431 | 0.883411 | 0.914446 |
20181218 | 0.965227 | 1.102679 | 1.195046 | 0.871406 | 0.905231 | 0.990134 | 0.973431 | 0.872978 | 0.903646 |
20181219 | 0.951122 | 1.086565 | 1.177583 | 0.871406 | 0.892003 | 0.990134 | 0.973431 | 0.860221 | 0.890441 |
20181220 | 0.951122 | 1.086565 | 1.163070 | 0.871406 | 0.892003 | 0.990134 | 0.973431 | 0.849620 | 0.879468 |
20181221 | 0.952395 | 1.086565 | 1.163070 | 0.871406 | 0.892003 | 0.990134 | 0.973431 | 0.850757 | 0.880645 |
20181224 | 0.947488 | 1.086565 | 1.163070 | 0.866916 | 0.887406 | 0.990134 | 0.973431 | 0.850757 | 0.876107 |
20181225 | 0.940981 | 1.086565 | 1.155083 | 0.866916 | 0.881312 | 0.990134 | 0.973431 | 0.844915 | 0.870091 |
20181226 | 0.938713 | 1.083946 | 1.152299 | 0.866916 | 0.881312 | 0.987747 | 0.973431 | 0.842878 | 0.867993 |
20181227 | 0.938713 | 1.083946 | 1.152299 | 0.866916 | 0.881312 | 0.987747 | 0.973431 | 0.849189 | 0.874492 |
20181228 | 0.938713 | 1.069619 | 1.137069 | 0.866916 | 0.881312 | 0.987747 | 0.960565 | 0.849189 | 0.874492 |
20190102 | 0.938713 | 1.072669 | 1.140310 | 0.866916 | 0.881312 | 0.987747 | 0.960565 | 0.851610 | 0.876985 |
20190103 | 0.957494 | 1.072669 | 1.140310 | 0.884261 | 0.898946 | 1.007510 | 0.960565 | 0.851610 | 0.894532 |
20190104 | 0.957359 | 1.072517 | 1.140150 | 0.884137 | 0.898819 | 1.007368 | 0.960565 | 0.851610 | 0.894532 |
20190107 | 0.957359 | 1.072517 | 1.140150 | 0.884137 | 0.898819 | 1.007368 | 0.960565 | 0.848243 | 0.890996 |
20190108 | 0.968801 | 1.085335 | 1.153776 | 0.894703 | 0.909561 | 1.019407 | 0.960565 | 0.858381 | 0.901644 |
20190109 | 0.968801 | 1.085335 | 1.153776 | 0.894703 | 0.909561 | 1.019407 | 0.960208 | 0.858381 | 0.901644 |
20190110 | 0.978210 | 1.095876 | 1.153776 | 0.903392 | 0.909561 | 1.019407 | 0.960208 | 0.866717 | 0.910401 |
20190111 | 0.978210 | 1.095876 | 1.153776 | 0.903392 | 0.909561 | 1.009291 | 0.960208 | 0.866717 | 0.910401 |
20190114 | 0.998030 | 1.118081 | 1.153776 | 0.921697 | 0.927990 | 1.029741 | 0.960208 | 0.884279 | 0.928847 |
20190115 | 0.999226 | 1.119420 | 1.155158 | 0.922801 | 0.929102 | 1.029741 | 0.960208 | 0.884279 | 0.928847 |
20190116 | 0.995084 | 1.114780 | 1.150370 | 0.918976 | 0.925251 | 1.029741 | 0.960208 | 0.880613 | 0.924997 |
20190117 | 1.014394 | 1.136412 | 1.172693 | 0.936809 | 0.943205 | 1.049723 | 0.960208 | 0.897702 | 0.942947 |
20190118 | 1.014394 | 1.136412 | 1.172693 | 0.942669 | 0.943205 | 1.049723 | 0.966215 | 0.897702 | 0.942947 |
20190121 | 1.001360 | 1.136412 | 1.172693 | 0.930557 | 0.931086 | 1.036236 | 0.966215 | 0.897702 | 0.930831 |
20190122 | 0.999615 | 1.134433 | 1.170650 | 0.930557 | 0.929464 | 1.034430 | 0.966215 | 0.896138 | 0.929209 |
20190123 | 0.999615 | 1.134433 | 1.170650 | 0.930557 | 0.929464 | 1.034430 | 0.966215 | 0.896138 | 0.929209 |
250 rows × 9 columns
In [257]:
# Plot every model's net-value curve on a single panel.
fig = plt.figure()
axes = fig.add_axes([0.1, 0.1, 1, 1.382])  # 插入面板 [left, bottom, width, height]
color = ['tomato', 'green', 'darkorchid', 'lightskyblue', 'y', 'gold',
         'deeppink', 'lightgoldenrodyellow', 'red']
t = list(dr.columns)
# enumerate() gives the color index directly; the original t.index(s)
# was an O(n) lookup on every iteration.
for g, s in enumerate(t):
    y = np.array(list(dr[s]))
    x = np.arange(len(y))
    axes.plot(x, y, color=color[g])
axes.set_xlabel('Time', fontsize=15)
axes.set_ylabel('net value', fontsize=15)
axes.set_title('AI return ', fontsize=20)
axes.legend(t)
# 设置X轴 -- label every 60th trading day with its date string.
mtradelist = list(test['date'])
ticks = list(range(0, len(mtradelist), 60))
numlist = [mtradelist[s] for s in ticks]
axes.set_xticks(ticks)
axes.set_xticklabels(numlist, fontsize=10)
查看以上策略详细请 到 supermind量化交易官网查看:机器算法运用--预测指数涨跌的算法对比 附源代码