本文代码主要包含 SVM 模型的训练,以及 POD、FAR、CSI 等评估指标的计算过程,之后有空再详细解释……
代码如下:
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
import seaborn as sns;sns.set()
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split,cross_val_score
from sklearn import metrics
from sklearn.metrics import roc_curve
def loaddata3(file):
    """Read a whitespace-separated three-column text file.

    Each non-blank line must hold at least three fields; the first two are
    parsed as ``float`` and the third as ``int``.  Blank lines are skipped.

    Args:
        file: Path of the text file to read.

    Returns:
        A tuple ``(x1, x2, x3, FPx, FPy, xall, yall, flagall)`` where
        ``x1``/``xall`` hold column 0, ``x2``/``yall`` hold column 1 and
        ``x3``/``flagall`` hold column 2.  The paired lists are independent
        copies with equal content.  ``FPx`` and ``FPy`` are always empty
        lists; they are kept only so the return shape stays the same.

    Raises:
        IndexError: If a non-blank line has fewer than three fields.
        ValueError: If a field cannot be converted to ``float``/``int``.
    """
    xall = []
    yall = []
    flagall = []
    # The context manager closes the handle even when a line fails to parse.
    with open(file) as fileIn:
        for line in fileIn:
            lineArr = line.split()
            if not lineArr:
                # A blank (or whitespace-only) line carries no sample.
                continue
            xall.append(float(lineArr[0]))
            yall.append(float(lineArr[1]))
            flagall.append(int(lineArr[2]))
    # Separate list objects, so mutating one never affects its twin.
    x1 = list(xall)
    x2 = list(yall)
    x3 = list(flagall)
    FPx = []
    FPy = []
    return x1, x2, x3, FPx, FPy, xall, yall, flagall
# ---------------------------------------------------------------------------
# Run configuration and data loading.
#
# The original paste contained this whole section twice (and the second
# dataset load twice again), so every file was parsed and every output file
# opened more than once.  Each step now runs exactly once; every name the
# original defined is still defined with the same final value.
# ---------------------------------------------------------------------------
Station = 'guangzhou'
Out_path = './fig/' + Station + '_POD.txt'
Fig_path = './fig/' + Station + '_DT.jpg'

# NOTE(review): both handles are opened in 'w' mode (creating/truncating the
# files), but nothing in the visible script writes to or closes them.  They
# are kept in case code outside this view relies on them.
f = open(Out_path, 'w')
f2 = open('./pod_far_curve.txt', 'w')

# Accumulators that the visible script initialises but never fills.
fig_name = []
black_last = 0
bnum = []
pod_arr = []
far_arr = []

# First dataset: x1/x2/x3 become the SVM training features and labels below.
# NOTE(review): the "9y"/"1y" path parts presumably mean nine years vs. one
# year of data; confirm against the data source.
infile = 'E:/MyPaper/Paper_next/py/input_bg_9y/ML_59287.txt'
x1, x2, x3, FPx, FPy, xalltmp, yalltmp, flagalltmp = loaddata3(infile)
xalltmp = np.array(xalltmp)
yalltmp = np.array(yalltmp)
# "falg..." is a misspelling of "flag..."; the name is kept unchanged.
falgalltmp = np.array(flagalltmp)

# Second dataset: xall/yall/flagall are the samples the fitted model is
# scored on further down (confusion counts, POD/FAR/CSI, ROC).
infile2 = 'E:/MyPaper/Paper_next/py/input_bg_9y/ML_59287_1y.txt'
x1tmp, x2tmp, x3tmp, FPx, FPy, xallpre, yallpre, flagall = loaddata3(infile2)
xall = np.array(xallpre)
yall = np.array(yallpre)
falgall = np.array(flagall)
# Training matrix: one row per sample, columns are (x1, x2); labels come
# from x3.
X = np.column_stack((x1, x2))
y = np.array(x3)
#print(y)

# Split the training coordinates by class.  A boolean mask replaces the
# original index loop, whose if/else bodies had lost their indentation.
# Anything whose label is not 1 goes to the "_0" group, as in the original
# else-branch.
_train_x = np.asarray(x1, dtype=float)
_train_y = np.asarray(x2, dtype=float)
_is_positive = y == 1
x_1 = _train_x[_is_positive]
y_1 = _train_y[_is_positive]
x_0 = _train_x[~_is_positive]
y_0 = _train_y[~_is_positive]
# Fit a linear-kernel SVM on the training matrix.
# NOTE(review): scikit-learn documents `gamma` as the coefficient of the
# rbf/poly/sigmoid kernels, so it presumably has no effect with
# kernel='linear'; confirm before relying on it.
clf = SVC(C=0.001, kernel='linear', gamma=0.1)
clf.fit(X, y)

# Accuracy on the very samples the model was fitted on.
score = clf.score(X, y)

# 5-fold cross-validated accuracy and recall on the same training data.
Acc = cross_val_score(clf, X, y, cv=5, scoring='accuracy')
Recall = cross_val_score(clf, X, y, cv=5, scoring='recall')

# Separating line W[0]*x + W[1]*y + l[0] = 0, rewritten as y = k*x + b.
W = clf.coef_[0]
l = clf.intercept_
print('cross acc:', np.mean(Acc), Acc)
print('recall acc:', np.mean(Recall), Recall)
k = -W[0] / W[1]
b = -l[0] / W[1]
print('b= ', b, ' k= ', k)
print('score=', score)
#calculate FAR POD
# Sample the separating line y = k*x + b at x = 0.0, 0.1, ..., 999.9 so it
# can be drawn over the scatter plot.
lxx = np.arange(10000) * 0.1
lyy = k * lxx + b

# Classify every evaluation sample with one batched predict() call.  The
# original called clf.predict() up to five times per sample inside a Python
# loop.  pred_label is now a 1-D array of predicted labels; it was a list of
# length-1 arrays before.
_eval_x = np.asarray(xall, dtype=float)
_eval_y = np.asarray(yall, dtype=float)
_truth = np.asarray(flagall)
_eval_points = np.column_stack((_eval_x, _eval_y))
if len(_eval_points):
    pred_label = np.asarray(clf.predict(_eval_points))
else:
    # predict() rejects an empty sample matrix, so skip the call.
    pred_label = np.empty(0, dtype=int)

# Confusion-matrix masks.  As in the original if/elif chain, samples whose
# true or predicted label is neither 0 nor 1 fall into no category.
_tp_mask = (pred_label == 1) & (_truth == 1)
_fn_mask = (pred_label == 0) & (_truth == 1)
_fp_mask = (pred_label == 1) & (_truth == 0)
_tn_mask = (pred_label == 0) & (_truth == 0)

TP = int(_tp_mask.sum())
FN = int(_fn_mask.sum())
FP = int(_fp_mask.sum())
TN = int(_tn_mask.sum())

# Coordinates per category, used by the scatter plot: a*/b* are the first
# and second feature of each sample (arrays now, lists before).
a1, b1 = _eval_x[_tp_mask], _eval_y[_tp_mask]   # TP
a2, b2 = _eval_x[_fn_mask], _eval_y[_fn_mask]   # FN
a3, b3 = _eval_x[_fp_mask], _eval_y[_fp_mask]   # FP
a4, b4 = _eval_x[_tn_mask], _eval_y[_tn_mask]   # TN


def _safe_ratio(numerator, denominator):
    """Return numerator/denominator, or NaN when the denominator is zero."""
    return numerator / denominator if denominator else float('nan')


# POD = TP / (TP + FN)
# FAR = FP / (FP + TP); the original kept FP / (FP + TN) as a
#       commented-out alternative
# CSI = TP / (TP + FP + FN)
POD = _safe_ratio(TP, TP + FN)
FAR = _safe_ratio(FP, FP + TP)
CSI = _safe_ratio(TP, TP + FP + FN)
# Scatter plot of the evaluation samples by confusion-matrix category, with
# the separating line drawn on top.
ax = plt.axes()
ax.set_facecolor("white")

# Coral marks the flag-0 groups (TN, FP) and deepskyblue the flag-1 groups
# (TP, FN); the 'x' marker marks the misclassified ones (FP, FN).  The draw
# order TN, FP, TP, FN also fixes the legend order.
for _px, _py, _style in (
    (a4, b4, {'label': "TN", 'c': 'coral'}),
    (a3, b3, {'label': "FP", 'c': 'coral', 'marker': 'x'}),
    (a1, b1, {'label': "TP", 'c': 'deepskyblue'}),
    (a2, b2, {'label': "FN", 'c': 'deepskyblue', 'marker': 'x'}),
):
    plt.scatter(_px, _py, s=9, **_style)

plt.xlim(0, 230)
plt.ylim(0, 120)
plt.xticks(fontsize=8)
plt.yticks(fontsize=8)
plt.xlabel('Daily Accumulative Rainfall (mm)', fontdict={'size': 10})
plt.ylabel('3h Accumulative Rainfall (mm)', fontdict={'size': 10})
plt.legend(fontsize=10, facecolor='white')

# Black frame on all four sides of the axes.
for _side in ('bottom', 'top', 'left', 'right'):
    ax.spines[_side].set_color('black')

# Dashed separating line y = k*x + b.
plt.plot(lxx, lyy, color='black', linewidth=2, linestyle='--')
print('TP:', TP, ' TN ', TN, ' FP:', FP, ' FN: ', FN, 'POD = ', POD, 'FAR = ', FAR, 'CSI =', CSI)

# Optional extras left disabled by the author:
#   plt.vlines(46, 0, 120, colors="black", linestyles="dashed")
#   plt.savefig(Fig_path, dpi=300)
plt.show()
# ROC curve of the fitted classifier on the evaluation samples.  The
# continuous decision_function scores are used rather than the hard 0/1
# predictions, so the curve has more than one operating point.
list3 = list(zip(xall, yall))
fpr, tpr, thresholds = roc_curve(flagall, clf.decision_function(np.array(list3)))
roc_auc = metrics.auc(fpr, tpr)

plt.figure(figsize=(6, 6))
plt.title('Validation ROC')
plt.plot(fpr, tpr, 'b', label='Val AUC = %0.3f' % roc_auc)
plt.legend(loc='lower right')
# Diagonal from (0, 0) to (1, 1), drawn for reference.
plt.plot([0, 1], [0, 1], 'r--')
# Both rates are bounded by 1; the original limits of [0, 2] confined the
# curve to the lower-left quarter of the figure.
plt.xlim([0, 1])
plt.ylim([0, 1])
plt.ylabel('True Positive Rate')
plt.xlabel('False Positive Rate')
plt.show()