# -*- coding: utf-8 -*-
"""
Created on Mon Dec 2 14:49:59 2018

@author: zhen

Pattern-similarity search on a single series read from a CSV file.

A fixed "base" window of the series is compared against every other window of
the same length that does not overlap it. The window with the smallest
correlation-based distance (see ``corrcoef``) is plotted next to the base
window, both in raw units and z-score normalized.
"""

from datetime import datetime

import numpy as np
import pandas as pd

# Input CSV. Columns 0 and 1 only feed the plot labels; column 3 (0-based)
# is the series that is searched.
FILENAME = 'C:/Users/zhen/.spyder-py3/sh000300_2017.csv'


def normal(a):
    """Min-max scale *a* to (approximately) [0, 1].

    1e-6 is added to the denominator so a constant input maps to zeros
    instead of dividing by zero. Not used by the script itself.
    """
    return (a - np.min(a)) / (np.max(a) - np.min(a) + 0.000001)


def normalization(x):
    """Return the z-score of *x*: ``(x - mean) / std`` (population std, ddof=0).

    A constant input (std == 0) yields an all-zero array instead of NaNs.
    """
    x = np.asarray(x, dtype=float)
    std = np.std(x)
    if std == 0:
        return np.zeros_like(x)
    return (x - np.mean(x)) / std


def corrcoef(a, b):
    """Correlation-based distance between two equal-length series.

    With r the Pearson correlation coefficient of *a* and *b*, this returns
    ``16 * (1 - r) / (1 + r)``. The result is 0 for perfectly correlated
    series and 16 for uncorrelated ones, and it grows without bound as r
    approaches -1. Returns ``inf`` when r is -1 or undefined (e.g. one series
    is constant), so such windows always rank last instead of producing
    NaN or a division by zero.
    """
    # A constant series makes np.corrcoef divide 0 by 0; silence that
    # warning and map the resulting NaN to "infinitely far" below.
    with np.errstate(divide='ignore', invalid='ignore'):
        r = np.corrcoef(a, b)[0, 1]
    if np.isnan(r) or r <= -1:
        return np.inf
    return 16 * (1 - r) / (1 + r)


def candidate_windows(points, base_begin, base_len):
    """Return ``(windows, starts)`` for every window that avoids the base window.

    A window is ``points[j:j + base_len]``. Every start ``j`` from 0 to
    ``len(points) - base_len`` inclusive is considered. A window is kept
    unless it overlaps the base window
    ``[base_begin, base_begin + base_len)``. Both lists are in increasing
    order of ``j``; they are empty if no window qualifies.
    """
    last_start = len(points) - base_len
    starts = [j for j in range(last_start + 1)
              if j + base_len <= base_begin or j >= base_begin + base_len]
    return [points[j:j + base_len] for j in starts], starts


def window_label(dates, times, start, win_len):
    """Build the ``'<dates> <times>~<dates> <times>'`` label for a window.

    The two endpoints are the first and last index of
    ``[start, start + win_len)``.
    """
    end = start + win_len - 1
    return '{} {}~{} {}'.format(dates[start], times[start], dates[end], times[end])


def _plot_results(base, base_label, match, match_label):
    """Draw the three comparison figures, then block in ``plt.show()``."""
    # Imported here so the numeric helpers above can be imported without a
    # plotting backend.
    import matplotlib.pyplot as plt

    # Figure 0: the base window to be matched.
    plt.figure(0)
    plt.plot(base, label=base_label, linewidth=2)
    plt.legend(loc='upper left')
    plt.title('Base data')

    # Figure 1: the best-matching window in raw units.
    plt.figure(1)
    plt.plot(match, label=match_label, linewidth=2)
    plt.legend(loc='upper left')
    plt.title('Similarity data')

    # Figure 2: both windows z-scored so their shapes are directly comparable.
    plt.figure(2)
    plt.plot(normalization(base), label=base_label, linewidth=2)
    plt.plot(normalization(match), label=match_label, linewidth=3)
    plt.legend(loc='lower right')
    plt.title('normal similarity search')
    plt.show()


def main(filename=FILENAME, baselen=100, basebegin=361):
    """Search *filename* for the window most similar to the base window and plot it.

    The base window is ``[basebegin, basebegin + baselen)``. The original
    author notes that one day of data is 240 points.

    Raises:
        ValueError: if the base window does not fit in the series, or no
            non-overlapping candidate window exists.
    """
    start_time = datetime.now()

    # Read the file once. Columns 0 and 1 are kept as text (presumably date
    # and time) so the labels use them verbatim; column 3 is the numeric
    # series.
    frame = pd.read_csv(filename, usecols=[0, 1, 3], dtype=str)
    values = np.asarray(frame)
    dates, times = values[:, 0], values[:, 1]
    points = values[:, 2].astype(float)

    if baselen < 2 or basebegin < 0 or basebegin + baselen > len(points):
        raise ValueError('base window [{}, {}) does not fit a {}-point series'
                         .format(basebegin, basebegin + baselen, len(points)))

    base = points[basebegin:basebegin + baselen]
    base_label = window_label(dates, times, basebegin, baselen)

    windows, starts = candidate_windows(points, basebegin, baselen)
    if not windows:
        raise ValueError('no window of length {} lies outside the base window'
                         .format(baselen))

    # corrcoef never returns NaN, so argmin reliably picks the closest window.
    distances = [corrcoef(base, w) for w in windows]
    best = int(np.argmin(distances))
    match = windows[best]
    match_label = window_label(dates, times, starts[best], baselen)

    _plot_results(base, base_label, match, match_label)

    # total_seconds() avoids timedelta.seconds silently dropping whole days.
    elapsed = datetime.now() - start_time
    print('run time', int(elapsed.total_seconds()), "s")


# Run only as a script, so importing this module has no I/O side effects.
if __name__ == '__main__':
    main()
結果: