1.支持向量机即SVM原理
支持向量机即SVM(Support Vector Machine) ,是一种监督学习算法,属于分类的范畴。它的原理就是求出“保证距离最近的点,距离它们最远的线”。例如图中要能划分出蓝色和红色球,并且最近的蓝色与和红色要距离线最远。
2.支持向量机算法
寻找最大分类间距
转而通过拉格朗日函数求优化的问题
数据可以通过画一条直线就可以将它们完全分开,这组数据叫线性可分(linearly separable)数据,而这条分隔直线称为分隔超平面(separating hyperplane)。如果数据集上升到1024维呢?那么需要1023维来分隔数据集,也就说需要N-1维的对象来分隔,这个对象叫做超平面(hyperlane),也就是分类的决策边界。
3.svm的应用
策略:利用sma、wma、mom指标、收盘价来训练svm,其中今天收盘价大于昨天收盘价买入股票,否则卖出股票。
以下案例是svm在股票预测中的应用。下面程序只是一个核心的demo,后续内容会在这demo中进行扩展。
# -*- coding: utf-8 -*-
"""
Created on Tue May 19 19:06:26 2020
@author: 觉醒2020
"""
import csv
import talib
import numpy as np
from sklearn import svm
"""
基于SVM的机器学习策略
步骤:
1.数据采集
2.训练
3.预测
"""
#数据采集
class FileManager():
def readInfo(fieldOption):
info=[]
datapath=".\\datas\\test\\000001.XSHE"
with open(datapath, 'r') as f:
reader = csv.reader(f)
i=0
for r in reader:
if i>0 :
info.append(r[fieldOption])
i=i+1
return info
# SVM训练分类器
def SVM_train():
print("开始训练")
close_price = FileManager.readInfo(2)
x_train = [] # 特征
y_train = [] # 标记
np_close_price=np.array(close_price,dtype='f8')
sma_data=talib.SMA(np_close_price,15)
wma_data=talib.WMA(np_close_price,15)
mom_data=talib.MOM(np_close_price,15)
#例如:训练前30天以前的数据
for i in range(-31,-len(sma_data)+30,-1):
sma_data_tmp=sma_data[i]
wma_data_tmp=wma_data[i]
mom_data_tmp=mom_data[i]
features = []
features.append(sma_data_tmp)
features.append(wma_data_tmp)
features.append(mom_data_tmp)
label = False # 标记为跌(False)
if np_close_price[-i] > np_close_price[-i-1]:
label=True
x_train.append(features)
y_train.append(label)
svm_module = svm.SVC()
svm_module.fit(x_train, y_train) # 训练分类器
print("训练结束")
return svm_module
#利用svm进行股票交易
def svm_trade():
close_price = FileManager.readInfo(2)
np_close_price=np.array(close_price,dtype='f8')
sma_data=talib.SMA(np_close_price,15)
wma_data=talib.WMA(np_close_price,15)
mom_data=talib.MOM(np_close_price,15)
svm_module=SVM_train()
for i in range(-1,-30,-1):
sma_data_tmp=sma_data[i]
wma_data_tmp=wma_data[i]
mom_data_tmp=mom_data[i]
features = []
x = []
features.append(sma_data_tmp)
features.append(wma_data_tmp)
features.append(mom_data_tmp)
x.append(features)
flag = svm_module.predict(x) # 预测的涨跌结果
if bool(flag):
print('买入=',flag)
else:
print('卖出=',flag)
if __name__=='__main__':
svm_trade()