# -*- coding: utf-8 -*-
"""
Created on Mon Dec 2 14:49:59 2018

@author: zhen
"""
7
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
def normal(a):
    """Min-max normalize ``a``.

    Args:
        a: Array-like of numeric values.

    Returns:
        ``(a - min(a)) / (max(a) - min(a) + 0.000001)``. The minimum maps
        to 0 and the maximum to slightly below 1.
    """
    lowest = np.min(a)
    # The small constant keeps the denominator non-zero when every value
    # in ``a`` is equal (the result is then all zeros).
    value_range = np.max(a) - lowest + 0.000001
    return (a - lowest) / value_range
def normalization(x):
    """Standardize ``x`` to zero mean and unit standard deviation.

    ``np.std`` is the standard deviation, i.e. the square root of the
    variance.

    Args:
        x: Array-like of numeric values.

    Returns:
        ``(x - mean(x)) / std(x)``. When the standard deviation is exactly
        zero (constant or single-element input) the result is all zeros
        instead of the NaN values a 0/0 division would produce.
    """
    centered = x - np.mean(x)
    spread = np.std(x)
    if spread == 0:
        # No variation in the input: every deviation from the mean is zero,
        # so return zeros rather than dividing by zero.
        return centered * 0.0
    return centered / spread
def corrcoef(a, b):
    """Return a correlation-based distance between the series ``a`` and ``b``.

    The Pearson correlation coefficient ``r`` of the two series measures
    how strongly they are correlated and lies between -1 and 1. It is
    mapped to ``16 * ((1 - r) / (1 + r)) ** 1``, so a perfectly correlated
    pair gives 0, an uncorrelated pair gives 16, and the value grows without
    bound as ``r`` approaches -1.

    Args:
        a: First series (array-like).
        b: Second series (array-like), same length as ``a``.

    Returns:
        The distance described above; smaller means more similar.
    """
    # np.corrcoef returns the 2x2 correlation matrix; the off-diagonal
    # entry is the coefficient between ``a`` and ``b``.
    pearson = np.corrcoef(a, b)[0, 1]
    # ``**`` is exponentiation; the exponent is 1 here.
    return 16 * ((1 - pearson) / (1 + pearson)) ** 1
23
startTimeStamp = datetime.now()  # current time, used for the run-time report

# Load the data.
filename = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'

# Columns 0 and 1 are read as strings and used to build the plot labels
# (presumably the date and the time of each sample -- confirm against the CSV).
all_date = pd.read_csv(filename, usecols=[0, 1, 3], dtype='str')
all_date = np.array(all_date)
data = all_date[:, 0]
times = all_date[:, 1]

# Column 3 is read a second time without the string dtype so that pandas
# infers its type; this is the series that is searched.
data_points = pd.read_csv(filename, usecols=[3])
data_points = np.array(data_points)
data_points = data_points[:, 0]  # the data values
36
topk = 10  # only show the top 10
baselen = 100  # number of points in the query window
basebegin = 361  # index of the first point of the query window

# Legend text for the query window: "<first label>~<last label>".
basedata = (
    data[basebegin] + ' ' + times[basebegin]
    + '~'
    + data[basebegin + baselen - 1] + ' ' + times[basebegin + baselen - 1]
)
base = data_points[basebegin:basebegin + baselen]  # one day of data is 240 points
length = len(data_points)  # length of the data
43
# Split the series into candidate segments.
# NOTE(review): the original loop was cut off after "... and j"; the end of
# the condition and the loop body are reconstructed from how `subseries` and
# `dateseries` are used afterwards (every segment is correlated with `base`,
# so it must hold `baselen` points, and `dateseries[k]` is used as the index
# of the segment's first point). Confirm against the original script.
subseries = []
dateseries = []
for j in range(0, length):
    # Skip start positions whose window would overlap the query window, and
    # positions too close to the end of the data to hold a full window.
    outside_base = j < (basebegin - baselen) or j > (basebegin + baselen - 1)
    if outside_base and j <= length - baselen:
        subseries.append(data_points[j:j + baselen])
        dateseries.append(j)  # start index of the segment
51
# Segment search: distance between the query window and every candidate.
listdistance = [corrcoef(base, np.array(segment)) for segment in subseries]

# Sort: np.argsort returns the indices that would sort the distances, so
# index[0] is the position of the candidate with the smallest distance.
index = np.argsort(listdistance, kind='quicksort')
61
# Display the data to be matched (the query window).
plt.figure(0)
plt.plot((base), label=basedata, linewidth='2')
plt.legend(loc='upper left')
plt.title('Base data')
# Raw data of the best match (the candidate with the smallest distance).
plt.figure(1)
num = index[0]
# `length` is rebound here from the length of the whole data set to the
# length of the matched segment.
length = len(subseries[num])
# Legend text: "<label of first point>~<label of last point>".
begin = data[dateseries[num]] + ' ' + times[dateseries[num]]
end = data[dateseries[num] + length - 1] + ' ' + times[dateseries[num] + length - 1]
label = begin + '~' + end
plt.plot((subseries[num]), label=label, linewidth='2')
plt.legend(loc='upper left')
plt.title('Similarity data')
# Compare the results: query window and best match on one figure, both
# standardized so that their shapes can be compared directly.
plt.figure(2)
plt.plot(normalization(base), label=basedata, linewidth='2')
length = len(subseries[num])
begin = data[dateseries[num]] + ' ' + times[dateseries[num]]
end = data[dateseries[num] + length - 1] + ' ' + times[dateseries[num] + length - 1]
label = begin + '~' + end
plt.plot(normalization(subseries[num]), label=label, linewidth='3')
plt.legend(loc='lower right')
plt.title('normal similarity search')
plt.show()
endTimeStamp = datetime.now()
# timedelta.seconds is only the seconds component of the difference (the
# days part is dropped); total_seconds() covers the whole elapsed time.
# It is truncated to an int so the printed value stays in whole seconds.
print('run time', int((endTimeStamp - startTimeStamp).total_seconds()), "s")