Python:利用 SVM 对二维样本进行分类

本文包含 SVM 模型的训练过程,以及在验证集上计算 POD、FAR、CSI 等检验指标的过程,其中 POD = TP/(TP+FN)(命中率),FAR = FP/(TP+FP)(空报率),CSI = TP/(TP+FP+FN)(临界成功指数)。更详细的解释之后有空再补充……
代码如下:

import os

import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns;sns.set()
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split,cross_val_score
from sklearn import metrics
from sklearn.metrics import roc_curve
def loaddata3(file):
    """Read a whitespace-separated three-column sample file.

    Each non-blank line must hold ``x y flag``: two float values followed by
    an integer label. Blank lines are skipped.

    Args:
        file: Path of the text file to read.

    Returns:
        An 8-tuple ``(x1, x2, x3, FPx, FPy, xall, yall, flagall)``:
        ``x1``/``xall`` hold the first-column floats, ``x2``/``yall`` the
        second-column floats and ``x3``/``flagall`` the integer labels (each
        pair has equal contents but is a separate list object). ``FPx`` and
        ``FPy`` are always empty lists, kept only so existing callers can
        keep unpacking eight values.

    Raises:
        OSError: if the file cannot be opened.
        IndexError / ValueError: if a non-blank line has fewer than three
            columns or a column cannot be converted.
    """
    x1 = []
    x2 = []
    x3 = []
    # `with` guarantees the handle is closed; the original left it open.
    with open(file) as fileIn:
        for line in fileIn:
            # str.split() with no argument also strips surrounding whitespace.
            lineArr = line.split()
            if not lineArr:
                # A blank line (e.g. an extra empty line at EOF) used to raise
                # IndexError; it carries no sample, so skip it.
                continue
            x1.append(float(lineArr[0]))
            x2.append(float(lineArr[1]))
            x3.append(int(lineArr[2]))
    # Return copies for the xall/yall/flagall slots so they remain distinct
    # list objects, exactly as the original's separate accumulators were.
    return x1, x2, x3, [], [], list(x1), list(x2), list(x3)

# --- Configuration, output files and input loading -------------------------
# The original post pasted this whole section twice (including a line where a
# comment swallowed the `Station = ...` statement), which opened the output
# files twice and read each input file twice. It now runs once.
Station = 'guangzhou'
Out_path = './fig/' + Station + '_POD.txt'
Fig_path = './fig/' + Station + '_DT.jpg'

# Opening a file inside ./fig fails if the directory is missing.
os.makedirs('./fig', exist_ok=True)
# NOTE(review): both files are opened for writing (which truncates them) but
# nothing in this script writes to them.
f = open(Out_path, 'w')
f2 = open('./pod_far_curve.txt', 'w')

# Placeholders kept from the original script; unused below.
fig_name = []
black_last = 0
bnum = []
pod_arr = []
far_arr = []

# Training set: rows of "x y label". The directory name suggests a 9-year
# record — TODO confirm.
infile = 'E:/MyPaper/Paper_next/py/input_bg_9y/ML_59287.txt'
x1, x2, x3, FPx, FPy, xalltmp, yalltmp, flagalltmp = loaddata3(infile)
xalltmp = np.array(xalltmp)
yalltmp = np.array(yalltmp)
falgalltmp = np.array(flagalltmp)

# Validation set, scored later with the model fitted on the training set.
# The "_1y" suffix suggests one year of data — TODO confirm.
infile2 = 'E:/MyPaper/Paper_next/py/input_bg_9y/ML_59287_1y.txt'
x1tmp, x2tmp, x3tmp, FPx, FPy, xallpre, yallpre, flagall = loaddata3(infile2)
xall = np.array(xallpre)
yall = np.array(yallpre)
falgall = np.array(flagall)

# Training design matrix: one [x1, x2] row per sample, plus the label vector.
X = np.array([[first, second] for first, second in zip(x1, x2)])
y = np.array(x3)

# Coordinates split by class: *_1 for samples labelled 1, *_0 for all others.
x_1 = np.array([first for first, label in zip(x1, x3) if label == 1])
y_1 = np.array([second for second, label in zip(x2, x3) if label == 1])
x_0 = np.array([first for first, label in zip(x1, x3) if label != 1])
y_0 = np.array([second for second, label in zip(x2, x3) if label != 1])


# Linear SVM fitted on the full training set.
# NOTE: `gamma` only affects the rbf/poly/sigmoid kernels; with
# kernel='linear' it has no effect.
clf = SVC(C=0.001, kernel='linear', gamma=0.1)
clf.fit(X, y)
score = clf.score(X, y)  # accuracy on the training data itself

# 5-fold cross-validated accuracy and recall (each fold refits a clone).
Acc = cross_val_score(clf, X, y, cv=5, scoring='accuracy')
Recall = cross_val_score(clf, X, y, cv=5, scoring='recall')

# Hyperplane w0*x + w1*y + b0 = 0 from the fitted linear model.
W = clf.coef_[0]
l = clf.intercept_
print('cross acc:', np.mean(Acc), Acc)
print('recall acc:', np.mean(Recall), Recall)

# Same hyperplane rewritten as a line y = k*x + b for plotting.
k = -W[0] / W[1]
b = -l[0] / W[1]
print('b= ', b, ' k= ', k)
print('score=', score)

# --- Validation: contingency table, POD / FAR / CSI and scatter plot -------
# Label convention: 1 = event, 0 = non-event (same as the training labels).

# Points on the fitted decision line y = k*x + b (x = 0.0, 0.1, ..., 999.9).
lxx = np.arange(10000) * 0.1
lyy = k * lxx + b

obs = np.asarray(flagall)
# One vectorised predict call replaces the original per-sample loop that
# called clf.predict up to five times for every point. predict() rejects an
# empty input, so fall back to an empty prediction array.
pred = (clf.predict(np.column_stack((xall, yall)))
        if len(xall) else np.zeros(0, dtype=int))
pred_label = list(pred)

hit = (pred == 1) & (obs == 1)          # TP: event predicted and observed
miss = (pred == 0) & (obs == 1)         # FN: event observed, not predicted
false_alarm = (pred == 1) & (obs == 0)  # FP: event predicted, not observed
correct_neg = (pred == 0) & (obs == 0)  # TN: no event, none predicted

a1, b1 = xall[hit], yall[hit]
a2, b2 = xall[miss], yall[miss]
a3, b3 = xall[false_alarm], yall[false_alarm]
a4, b4 = xall[correct_neg], yall[correct_neg]
TP = int(hit.sum())
FN = int(miss.sum())
FP = int(false_alarm.sum())
TN = int(correct_neg.sum())

# Scores are NaN (instead of raising ZeroDivisionError) when undefined.
# POD: probability of detection, the fraction of observed events predicted.
POD = TP / (TP + FN) if TP + FN else float('nan')
# FAR: false alarm *ratio*, the fraction of predicted events that did not occur.
# (FP / (FP + TN) would instead be the false positive rate used by ROC.)
FAR = FP / (FP + TP) if FP + TP else float('nan')
# CSI: critical success index (threat score).
CSI = TP / (TP + FP + FN) if TP + FP + FN else float('nan')

ax = plt.axes()
ax.set_facecolor("white")
plt.scatter(a4, b4, label="TN", s=9, c='coral')
plt.scatter(a3, b3, label="FP", s=9, c='coral', marker='x')
plt.scatter(a1, b1, label="TP", s=9, c='deepskyblue')
plt.scatter(a2, b2, label="FN", s=9, c='deepskyblue', marker='x')
plt.xlim(0, 230)
plt.ylim(0, 120)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.xlabel('Daily Accumulative Rainfall (mm)', fontdict={'size': 10})
plt.ylabel('3h Accumulative Rainfall (mm)', fontdict={'size': 10})
plt.legend(fontsize=10, facecolor='white')
for side in ('bottom', 'top', 'left', 'right'):
    ax.spines[side].set_color('black')
# SVM decision boundary.
plt.plot(lxx, lyy, color='black', linewidth=2, linestyle='--')
print('TP:', TP, ' TN ', TN, ' FP:', FP, ' FN: ', FN, 'POD = ', POD, 'FAR = ', FAR, 'CSI =', CSI)
#plt.vlines(46, 0, 120, colors = "black", linestyles = "dashed")
#plt.savefig(Fig_path, dpi=300)
plt.show()


# --- ROC curve on the validation set ---------------------------------------
# Score samples with the signed SVM decision value rather than hard 0/1
# predictions, so roc_curve can sweep many thresholds.
list3 = list(zip(xall, yall))
fpr, tpr, thresholds = roc_curve(flagall, clf.decision_function(np.array(list3)))
roc_auc = metrics.auc(fpr, tpr)
plt.figure(figsize=(6, 6))
plt.title('Validation ROC')
plt.plot(fpr, tpr, 'b', label='Val AUC = %0.3f' % roc_auc)
plt.legend(loc='lower right')
plt.plot([0, 1], [0, 1], 'r--')  # chance-level diagonal
# FPR and TPR are rates in [0, 1]; the original [0, 2] limits left three
# quarters of the axes empty.
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()

# The output files opened during setup are never written by this script;
# close them explicitly so the handles are not leaked.
f.close()
f2.close()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

宇天y

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值