fastdtw计算两个股票序列之间的相似性(最小距离)
1.动态时间扭曲 (DTW)
Dynamic time warping:动态时间扭曲 (DTW) 是一种在两个时间序列之间找到最佳对齐的技术,其中一个时间序列可以通过拉伸或收缩其时间轴来非线性地“扭曲”。 这种比对可用于找到对应的区域或确定两个时间序列之间的相似性。
DTW 经常用于语音识别,以确定两个波形是否代表相同的口语短语。 在语音波形中,每个语音的持续时间和声音之间的间隔是允许变化的,但整体语音波形必须相似。
既然是时间序列的数据对比,分析相似度,那么当然也可以用于分析两个股票走势的相似度。
2. 随机序列验证测试
用随机数测试:
import numpy as np
import fastdtw
# 定义两个时间序列
series1 = np.random.rand(20) # 使用 rand 生成 [0, 1) 内的随机数
series2 = np.random.rand(20) # 同样生成 [0, 1) 内的随机数
# 使用fastdtw计算这两个序列之间的最小距离
distance, path = fastdtw.fastdtw(series1, series2)
print('FastDTW distance:', distance)
FastDTW distance: 4.126319461445138
数据范围,变化不大,距离很小,相似度高。
3.时间序列验证测试
通过两个曲线验证一下距离,理解fastdtw函数:
# 数据标准化函数
def standardize_data(df):
return (df - df.mean()) / df.std()
def show_plot(amplitude_a,amplitude_b):
# 使用fastdtw计算这两个序列之间的最小距离
distance, path = fastdtw.fastdtw(standardize_data(amplitude_a), standardize_data(amplitude_b))
print('FastDTW distance:', distance)
# 计算相似度
distance = m._dtw_distance(standardize_data(amplitude_a), standardize_data(amplitude_b))
print('FastDTW distance:', distance)
# 创建一个大图,包含两个子图
fig, axs = plt.subplots(2, 1, figsize=(12, 8))
# 第一个子图:展示标准化后的数据
axs[0].plot(time, standardize_data(amplitude_a), label='amplitude_a_stand')
axs[0].plot(time, standardize_data(amplitude_b), label='amplitude_b_stand')
axs[0].set_title('Standardized Data')
axs[0].set_ylabel('Amplitude (Standardized)')
axs[0].set_xlabel('Time')
axs[0].legend()
# 第二个子图:展示原始数据
axs[1].plot(time, amplitude_a, label='amplitude_a')
axs[1].plot(time, amplitude_b, label='amplitude_b')
axs[1].set_title('Original Data')
axs[1].set_ylabel('Amplitude')
axs[1].set_xlabel('Time')
axs[1].legend()
# 添加整个图的标题(可选)
fig.suptitle('DTW distance between %s and %s is %.2f' % ('sin', 'cos', distance))
# 显示图形
plt.tight_layout() # 自动调整子图参数, 使之填充整个图像区域
plt.show()
time = np.linspace(0,10,1000)
# 定义两个时间序列曲线
amplitude_a = np.sin(time)
amplitude_b = np.cos(time)
show_plot(amplitude_a,amplitude_b)
FastDTW distance: 180.4045195245856
图示:
稍微复杂一点的曲线
time = np.linspace(0,10,1000)
amplitude_a = 8*np.sin(time+np.pi/6) + np.sin(time)/2
amplitude_b = 2*np.cos(time+np.pi/12) + 2*np.sin(time)/2
FastDTW distance: 70.96382215058884
可以看到振幅不同,经过标准化调整后,比cos和sin相似度更大。
距离比正弦曲线和余弦曲线近。
4.指数和股票序列之间相似性
下面做点有趣的测试,用沪深300的指数,测试范围两年的数据,在股票中寻找走势与之相似的股票。
选择相似度高的,距离小于70的,展示一下:
一共8个股票
less_perc =
{'002080.SZ': 67.35405249048497, '300055.SZ': 69.23801473990225, '300894.SZ': 56.61913877059769, '300999.SZ': 63.85063120595672, '600010.SH': 60.95074038146817, '600315.SH': 67.58673618181477, '600597.SH': 69.34313634300776, '603486.SH': 62.23227702713981}
代码:
# 数据标准化函数
def standardize_data(df):
return (df - df.mean()) / df.std()
def show_distance(time,index_code,stock_code,index_name,stock_name):
# 使用fastdtw计算这两个序列之间的最小距离
distance, path = fastdtw.fastdtw(standardize_data(index_code), standardize_data(stock_code))
print('FastDTW distance:', distance)
# 创建一个大图,包含两个子图
fig, axs = plt.subplots(3, 1, figsize=(12, 8))
# 第一个子图:展示标准化后的数据
axs[0].plot(time, standardize_data(index_code), label='index_code_stand')
axs[0].plot(time, standardize_data(stock_code), label='stock_code_stand')
axs[0].set_title('Standardized Data')
axs[0].set_ylabel('Amplitude (Standardized)')
axs[0].set_xlabel('Time')
axs[0].legend()
# 第二个子图:展示原始数据
axs[1].plot(time, index_code, label='index_code %s' % index_name)
axs[1].set_title('Original Index Data %s' % index_name )
axs[1].set_ylabel('Amplitude')
axs[1].set_xlabel('Time')
axs[1].legend()
# 第三个子图:展示原始数据
axs[2].plot(time, stock_code, label='stock_code %s' % stock_name)
axs[2].set_title('Original Stock Data %s' % stock_name)
axs[2].set_ylabel('Amplitude')
axs[2].set_xlabel('Time')
axs[2].legend()
# 添加整个图的标题(可选)
fig.suptitle('DTW distance between %s and %s is %.2f' % (index_name,stock_name, distance))
# 显示图形
plt.tight_layout() # 自动调整子图参数, 使之填充整个图像区域
plt.show()
index_code = get_index('000300.SH','2022-01-01','2024-01-01')
for key in less_perc:
print(key,less_perc[key])
stock_code=get_code(key,'2022-01-01','2024-01-01')
print(len(stock_code['close'].values),len(index_code['close'].values))
# 两个数组的长度相同,才能比较,建立同样的时间线
if len(stock_code['close'].values) == len(index_code['close'].values) :
time = np.linspace(0,10,len(index_code['close'].values))
show_distance(time,index_code['close'],stock_code['close'],'000300.SH',key)
挑选两个图示效果:
002080.SZ 67.35405249048497
484 484
FastDTW distance: 62.78944947058509
300894.SZ 56.61913877059769
484 484
FastDTW distance: 56.984793731052015
通过两个图示对比,走势的相似度还是比较接近的,拟合的比较完美。
用指数和市场全量对比一遍,还是需要点时间。
fastdtw可以做点有趣的事情!