编辑距离,是指将一个字符串通过修改,删除,增加三种操作变化为另外一个字符串,编辑距离问题(levensthein)这个过程中要求这三种操作数量最少。编辑距离公式如下:
,其中,
。
由此公式可以推导出来递归方法:
该方法如下:
#include <iostream>
#include <iterator>
#include <string>
namespace wscad {
template <typename Iter>
using DistanceType = typename std::iterator_traits <Iter> :: difference_type ;
template <typename Iter> inline auto
pred (Iter i) -> Iter
{
return (-- i) ;
}
template <typename Iter> inline auto
dist (Iter first, Iter last) -> DistanceType <Iter>
{
return std::distance (first, last) ;
}
template <typename T> inline auto
min (const T& a, const T& b, const T& c) -> const T&
{
return std::min (std::min (a, b), c) ;
}
template <typename Iter> auto
lev_dist (Iter f1, Iter l1, Iter f2, Iter l2) -> DistanceType <Iter>
{
using D = DistanceType <Iter> ;
D cost {} ;
if (f1 == l1) return dist (f2, l2) ;
if (f2 == l2) return dist (f1, l1) ;
cost = (*pred (l1) == *pred (l2)) ? D (0) : D (1) ;
return min
( lev_dist (f1, pred (l1), f2, l2 ) + D (1)
, lev_dist (f1, l1, f2, pred (l2)) + D (1)
, lev_dist (f1, pred (l1), f2, pred (l2)) + cost ) ;
}
} // end namespace wscad
int main (int argc, char const* argv [])
{
std::string s ;
std::string t ;
std::getline (std::cin, s) ;
std::getline (std::cin, t) ;
std::cout
<< wscad::lev_dist
( std::begin (s)
, std::end (s)
, std::begin (t)
, std::end (t) )
<< std::endl ;
}
该递归方法直接有上述公式写出来的,由于不断地递归,导致算法效率严重降低,故此可以根据上述公式推导出非递归方法,利用矩阵进行计算,一般的矩阵计算需要维护m*n的二维数组,但是以下代码并不是通过维护一个二维数组矩阵,只是维护了三个一维数组(m+1),这样既可以减少空间使用,同事也为下面的并行计算提供了基础。
该非递归方法如下:
void serial_operation(int m, int n, int index, int *L, int *PL, int *PP, std::string str1, std::string str2)
{
int cow, row, cost, buffer_size;
while (index < m + n + 1)
{
if (index < n + 1)
{
if (index > m)
{
buffer_size = m;
L[0] = index;
}
else
{
buffer_size = index - 1;
L[0] = index;
L[index] = index;
}
for (int i = 1; i <= buffer_size; i++)
{
cow = index - i;
row = i;
if (str1[cow - 1] == str2[row - 1])
cost = 0;
else
cost = 1;
L[i] = min(PL[i - 1] + 1, PL[i] + 1, PP[i - 1] + cost);
}
}
else if (index > n + 1)
{
buffer_size = m + n + 1 - index;
for (int i = 1; i <= buffer_size; i++)
{
row = index - n + i - 1;
cow = n - i + 1;
if (str1[cow - 1] == str2[row - 1])
cost = 0;
else
cost = 1;
L[i - 1] = min(PL[i - 1] + 1, PL[i] + 1, PP[i] + cost);
}
}
else
{
buffer_size = m;
for (int i = 1; i <= buffer_size; i++)
{
cow = index - i;
row = i;
if (str1[cow - 1] == str2[row - 1])
cost = 0;
else
cost = 1;
L[i - 1] = min(PL[i - 1] + 1, PL[i] + 1, PP[i - 1] + cost);
}
}
for (int i = 0; i < m + 1; i++)
{
PP[i] = PL[i];
PL[i] = L[i];
}
index++;
}
}
其中m是行数,也就是比较短的那个字符串str2的长度,
n是列数,较长字符串str1的长度
index是斜边的索引,范围理论上从0到m+n,但是这里初始化就是2
L是需要计算的斜边
PL是L斜边前面的一条斜边
PP是PL斜边前面的一条斜边
代码逻辑如下图所示: