# -*- coding: utf-8 -*-
"""
Created on Tue Dec 4 08:53:08 2018

@author: zhen
"""
import threading
import time
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from dtw import fastdtw
def normalization(x):
    """Return the z-score normalisation of *x*: ``(x - mean) / std``.

    Uses the population standard deviation (``np.std`` with its default
    ``ddof=0``). A constant series (standard deviation exactly 0) maps to
    all zeros instead of producing NaNs from a division by zero, and an
    empty input returns an empty array.

    Args:
        x: a 1-D sequence of numbers (list, tuple or numpy array).

    Returns:
        numpy.ndarray of floats with the same length as *x*.
    """
    values = np.asarray(x, dtype=float)
    if values.size == 0:
        return values
    std = np.std(values)
    if std == 0:
        # Nothing to scale: every point equals the mean.
        return np.zeros_like(values)
    return (values - np.mean(values)) / std
def corrcoef(a, b, scale=16, power=1):
    """Turn the Pearson correlation of *a* and *b* into a distance.

    Computes ``scale * ((1 - r) / (1 + r)) ** power``, where ``r`` is the
    Pearson correlation coefficient of the two inputs (in [-1, 1]).
    Perfectly correlated inputs (r = 1) give 0. The distance grows as r
    decreases and is ``inf`` for perfectly anti-correlated inputs (r = -1).

    Args:
        a, b: 1-D sequences of equal length.
        scale: multiplier applied to the ratio (default 16, the original
            hard-coded value).
        power: exponent applied to the ratio (default 1, i.e. no change).

    Returns:
        The distance as a float. It is NaN when either input is constant,
        because ``np.corrcoef`` cannot define r in that case.
    """
    r = np.corrcoef(a, b)[0, 1]
    if r <= -1:
        # 1 + r == 0: report an infinite distance instead of dividing by zero.
        return np.inf
    return scale * ((1 - r) / (1 + r)) ** power
print("begin Main Thread")
startTimeStamp = datetime.now()  # wall-clock start, used for the final "run time" report

# Load the data.
# NOTE(review): hard-coded, machine-specific path.
filename = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'
# Columns 0 and 1, read as strings, are only used to build readable labels
# (presumably date and time of day -- confirm against the CSV header).
all_date = np.array(pd.read_csv(filename, usecols=[0, 1], dtype='str'))
data = all_date[:, 0]
times = all_date[:, 1]

# Column 3 holds the numeric series that is searched.
data_points = np.array(pd.read_csv(filename, usecols=[3]))[:, 0]

topk = 10  # intended number of top matches to show
baselen = 100  # length of the query (base) pattern; candidate lengths range over 0.5x-1.5x of it
basebegin = 361  # index in data_points where the query pattern starts
length = len(data_points)  # number of points in the series

# Fail early with a clear message instead of an IndexError below.
if basebegin < 0 or baselen <= 0 or basebegin + baselen > length:
    raise ValueError(
        "base window [%d, %d) does not fit in a series of length %d"
        % (basebegin, basebegin + baselen, length)
    )

# Human-readable "start~end" label of the query pattern.
basedata = (data[basebegin] + ' ' + times[basebegin] + '~'
            + data[basebegin + baselen - 1] + ' ' + times[basebegin + baselen - 1])
# Custom worker-thread class.
class Thread_Local(threading.Thread):
    """Worker thread that runs the segment search for one candidate length.

    Attributes:
        thread_id: caller-supplied identifier (stored; not read by this class).
        name: thread name, printed when the thread starts.
        counter: candidate segment length, handed to ``split_data``.
    """

    def __init__(self, thread_id, name, counter):
        super().__init__()
        self.thread_id, self.name, self.counter = thread_id, name, counter
        # Run flag: set on creation, cleared by stop().
        # NOTE(review): nothing in this class reads the flag, so stop()
        # does not interrupt run().
        self.__running = threading.Event()
        self.__running.set()

    def run(self):
        print("starting %s" % self.name)
        # split_data is a module-level function that receives this thread as `self`.
        split_data(self, self.counter)

    def stop(self):
        """Clear the run flag (it is not checked inside this class)."""
        self.__running.clear()
60 # Split the series into candidate segments and run the pattern match (multi-threaded)
# NOTE(review): in this copy the `if` condition below is cut off and the loop
# body that fills `subseries`/`dateseries` is missing; restore it before running.
# split_data is invoked from Thread_Local.run() with the worker thread as `self`
# and the thread's candidate length as `split_len`.
61 defsplit_data(self, split_len):62 base = data_points[basebegin:basebegin+baselen] # query segment: baselen points starting at basebegin
63 subseries =[]64 dateseries =[]65 for j inrange(0, length):66 if (j < (basebegin - split_len) or j > (basebegin + split_len - 1)) and j
69 search(self, subseries, base, dateseries) # run DTW matching over the collected candidates
70
# Shared results, one entry per worker thread. The three lists are kept
# index-aligned: entry i of each describes the same thread's best match.
result = []
base_list = []
date_list = []
# Serialises the appends below so that one thread's three appends cannot be
# interleaved with another thread's (which would misalign the lists).
_result_lock = threading.Lock()


def search(self, subseries, base, dateseries):
    """Find the candidate segment closest to *base* under DTW and record it.

    Args:
        self: the worker thread running the search; only its ``stop()`` is used.
        subseries: candidate segments (sequences of numbers).
        base: the query segment.
        dateseries: start index of each candidate in the full series,
            aligned with *subseries*.

    The best candidate, *base*, and the candidate's start index are appended
    atomically to ``result``, ``base_list`` and ``date_list``. Nothing is
    recorded when *subseries* is empty. ``self.stop()`` is always called.
    """
    # fastdtw returns (distance, cost, acc_cost, path); only the distance is used.
    # (corrcoef(base, candidate) was an alternative metric for this step.)
    listdistance = [
        fastdtw(base, np.array(candidate), dist='euclidean')[0]
        for candidate in subseries
    ]
    if listdistance:
        best = int(np.argmin(listdistance))  # first minimum on ties
        with _result_lock:
            result.append(subseries[best])
            print("result length is %d" % len(result))
            base_list.append(base)
            date_list.append(dateseries[best])
    self.stop()
# Launch one worker per candidate segment length (shrink/stretch the pattern):
# from 0.5*baselen up to, but excluding, 1.5*baselen in steps of 10.
threads = []
for loc, split_len in enumerate(range(round(0.5 * baselen), round(1.5 * baselen), 10)):
    thread = Thread_Local(1, "Thread" + str(loc), split_len)
    thread.start()
    threads.append(thread)

# Wait for every worker, printing progress roughly every 100 s. Joining the
# threads (instead of polling until len(result) reaches a fixed 10) neither
# hangs when a worker dies nor stops early when more than 10 workers run.
waited = 0
for thread in threads:
    while thread.is_alive():
        thread.join(timeout=1)
        waited += 1
        if waited % 100 == 0:
            print("has running %d s" % waited)
# Re-score each thread's winner against its query pattern with DTW
# (element 0 of fastdtw's result is the distance).
listdistance = [
    fastdtw(base_list[i], np.array(candidate), dist='euclidean')[0]
    for i, candidate in enumerate(result)
]
# Final ranking: positions into result/base_list/date_list, best match first.
index = np.argsort(listdistance, kind='quicksort')
print("closed Main Thread")
endTimeStamp = datetime.now()
# Compare the query pattern with its best match (both z-score normalised).
if len(index) == 0:
    print('no match found')
else:
    best = index[0]
    plt.figure(0)
    plt.plot(normalization(base_list[best]), label=basedata, linewidth=2)
    # Local names so the global series length `length` is not overwritten.
    match_len = len(result[best])
    match_start = date_list[best]
    match_end = match_start + match_len - 1
    begin = data[match_start] + ' ' + times[match_start]
    end = data[match_end] + ' ' + times[match_end]
    label = begin + '~' + end
    plt.plot(normalization(result[best]), label=label, linewidth=2)
    plt.legend(loc='lower right')
    plt.title('normal similarity search')
    plt.show()
# total_seconds() does not wrap after one day, unlike timedelta.seconds.
print('run time', int((endTimeStamp - startTimeStamp).total_seconds()), "s")