DTW 的理解思路还是按照动态规划的思路 ,和LeetCode的72题编辑距离以及求最短路径类似。DTW会重复使用序列中的点,从而达到扭曲对齐的.
一般都是用两个指针i,j分别指向两个列表的最后,然后一步步往前走,缩小问题的规模。先计算a[i]和b[j]的两点距离,然后开始移动指针i和j,可以i,j一起移动到i-1,j-1,也可以i或者j只移动一个即i-1,j和 i,j-1。那么dp[i,j]= distance(i,j)+min(dp[i-1,j-1],dp[i-1,j],dp(i,j-1))
- dp[i,j]的含义是存储两个序列a,b的最短路径距离
- dp[i,j]可以由dp[i-1,j],dp[i,j-1],dp[i-1,j-1]推导得到,从三者中找出最小值再加上a[i]和b[j]的两点距离
- base case就是i,j为0的时候,设为无穷大即可
import numpy as np
a = np.random.randint(0,5,5)
b = np.random.randint(0,5,2)
a,b
(array([3, 1, 3, 1, 4]), array([1, 2]))
l1 = len(a)
l2 = len(b)
dp table备忘录
dp = np.full((l1+1,l2+1),fill_value=float('inf'))
dp[0,0]=0
choices记录移动方向,初始化
最终要从dp[i,j]往dp[1,1]的回推
choices = np.full((l1+1,l2+1),fill_value='45')
choices
array([['45', '45', '45'],
['45', '45', '45'],
['45', '45', '45'],
['45', '45', '45'],
['45', '45', '45'],
['45', '45', '45']], dtype='<U2')
计算两点距离
def distance(m,n):
return np.abs((m-n))
DTW
for i in range(1,l1+1):
for j in range(1,l2+1):
which = np.argmin((dp[i-1,j-1],dp[i-1,j],dp[i,j-1]))
if which==0:
pass
elif which==1:
choices[i,j]='up'
else:
choices[i,j] = 'lf'
dp[i,j] = min(dp[i-1,j-1],dp[i-1,j],dp[i,j-1])+distance(a[i-1],b[j-1])
dp
array([[ 0., inf, inf],
[inf, 2., 3.],
[inf, 2., 3.],
[inf, 4., 3.],
[inf, 4., 4.],
[inf, 7., 6.]])
choices
array([['45', '45', '45'],
['45', '45', 'lf'],
['45', 'up', '45'],
['45', 'up', '45'],
['45', 'up', 'up'],
['45', 'up', '45']], dtype='<U2')