原文来自:Sumanth Tambe 博客
背景
memoization记忆化或memoisation是记忆功能“记住”与某些特定输入相对应的结果。使用记忆输入的后续调用将返回记住的结果而不是重新计算结果。
Memoization也被用于其他上下文(以及速度增益以外的目的),例如简单的相互递归下降解析。虽然与缓存有关,但memoization指的是此优化的特定情况,将其与缓存形式(如缓冲或页面替换)区分开来。在一些逻辑编程语言的上下文中,memoization也称为tabling;
“memoization”一词是由Donald Michie于1968年创造的,通常在美国英语中被截断为“备忘录”,因而具有“将[函数的结果]转化为需要记住的东西。“ 虽然“memoization”可能与“memorization”混淆(词源同源词),“memoization”在计算一个专门的意义。
示例
考虑一个简单的斐波纳契程序:
unsigned long fibonacci(unsigned n)
{
return (n < 2) ? n : fibonacci(n - 1) + fibonacci(n - 2);
}
该算法是计算第N个斐波纳契数(N从0开始),上面的计算方式的缓慢的令人沮丧,它做了很多冗余的重新计算。但这个程序的美妙之处在于它非常简单。为了加快速度而不显着改变逻辑,我们可以使用memoization。
使用一些聪明的C ++ 11技术,可以记住这个功能,如下所示。
unsigned long fibonacci(unsigned n)
{
return(n <2)?n:memoized_recursion(fibonacci)(n - 1)+
memoized_recursion(fibonacci)(n - 2);
}
我们自己建立了一个memoized_recursion函数用于内存管理。
template <typename ReturnType,typename ... Args>
std :: function <ReturnType(Args ...)>
memoize(ReturnType(* func)(Args ...))
{
auto cache = std :: make_shared <std :: map <std :: tuple <Args ...>,ReturnType >>();
return([=](Args ... args)mutable {
std :: tuple <Args ...> t(args ...);
if(cache-> find(t)== cache-> end())
(* cache)[t] = func(args ...);
return(* cache)[t];
});
}
函数memoize接受指向自由函数的指针,将其包装在lambda中,并将lambda转换为std :: function。返回std :: function是从创建它的函数返回lambda的常见C ++ 11习惯用法。
如果您熟悉C ++ 11 可变参数模板,那么lambda的实现非常简单。它创建一个参数元组并检查它是否存在于缓存中。在这种情况下,返回存储的结果而不是重新计算它。用于将参数映射到返回值的高速缓存是动态分配的。std :: shared_ptr管理内存。lambda按值复制std :: shared_ptr。只要至少有一个std :: function活着,缓存将保持不变。
可以从不同的地方调用记忆功能。在所调用的任何地方传递memoized函数是相当麻烦的。应该有一种方法来查找函数的memoized版本而不会丢失状态。因此,我们的下一步是在程序的任何位置提供相同的memoized功能。我们需要一个指向记忆的std :: function的函数指针映射。具体来说,我们需要一个std :: unordered_map来快速查找。
template <typename F_ret,typename ... F_args>
std :: function <F_ret(F_args ...)>
memoized_recursion(F_ret(* func)(F_args ...))
{
typedef std :: function <F_ret(F_args .. 。)> FunctionType;
static std :: unordered_map <decltype(func),FunctionType> functor_map;
if(functor_map.find(func)== functor_map.end())
functor_map [func] = memoize(func);
return functor_map [func];
}
这里我介绍我们的递归fibonacci函数调用的“memoized_recursion”函数。它有一个静态的std :: unordered_map。它只是根据函数指针值查找记忆的std :: function。如果找不到,则创建它并将其存储以供后续访问。函数指针是唯一的; 所以没有碰撞可能。以下是如何调用它。
memorized_recursion(fibonacci)(10);
但是解决方案尚未完成。记忆显然很快就建立了状态。如果使用大量参数记忆许多函数,则状态会爆炸性增长。必须有某种方法来回收记忆。
请记住,memoized状态在lambda中增长。动态分配的映射存储相应功能的高速缓存。我们需要访问隐藏在lambda中的对象。Lambda有一个编译器定义的类型,你可以用它来做它只能调用它。那么我们如何清除它正在构建的缓存呢?
答案非常简单!只需将memoizer(lambda)与另一个默认的初始化memoizer分配即可
我们已经有memoize函数,它返回一个默认的初始化memoizer。我们只是分配新的一个代替旧的。以下是新memoized_recursion的外观
enum class Cache : unsigned int { NO_RECLAIM, RECLAIM };
template <typename F_ret, typename... F_args>
std::function<F_ret (F_args...)>
memoized_recursion(F_ret (*func)(F_args...), Cache c = Cache::NO_RECLAIM)
{
typedef std::function<F_ret (F_args...)> FunctionType;
static std::unordered_map<decltype(func), FunctionType> functor_map;
if(Cache::RECLAIM == c)
return functor_map[func] = memoize(func);
if(functor_map.find(func) == functor_map.end())
functor_map[func] = memoize(func);
return functor_map[func];
}
我正在使用强类型枚举来传递程序员清除缓存的意图。这就是你怎么称呼它。
memoized_recursion(fibonacci,Cache :: RECLAIM);
纯静态记忆器
严格地说,在memoized_recursion函数中使用std :: unordered_map是没有必要的。它是函数指针与其相应的memoized函数对象(包裹在std :: function中的lambda)的O(1)映射。实现相同映射的另一种方法是使用纯静态memoizer。我称之为纯粹,因为没有像std :: unordered_map那样的动态分配。仿函数仅存储为静态对象。只有当memoized_recursion可以单独用于所有可能的自由函数时,才有可能实现这一点!请注意,每个自由函数都保证具有唯一的指针值,指针值可用作模板参数。所以这里是如何在static_memoizer中组合所有这些东西。
template <typename Sig,Sig funcptr>
struct static_memoizer;
template <typename F_ret,typename ... F_args,F_ret(* func)(F_args ...)>
struct static_memoizer <F_ret(*)(F_args ...),func>
{
static
std :: function <F_ret(F_args。 ..)>&
get(Cache c = Cache :: NO_RECLAIM)
{
static std :: function <F_ret(F_args ...)> mfunc(memoize(func));
if(Cache :: RECLAIM == c)
mfunc = memoize(func);
return mfunc;
}
};
#define STATIC_MEMOIZER(func)static_memoizer <decltype(&func),&func>
STATIC_MEMOIZER宏简化了static_memoizer的使用。它使用decltype提取函数指针的类型,并将其(类型)作为模板的第一个参数传递。第二个参数是实际的函数指针。将函数指针作为模板参数传递非常重要,因为许多函数可能共享相同的签名但从不使用相同的指针。
static_memoizer使用的静态对象与memoized_recursion不同。所以我们要重写fibonacci函数来使用static_memoizer。
unsigned long fibonacci(unsigned n)
{
return (n < 2) ? n :
STATIC_MEMOIZER(fibonacci)::get()(n - 1) +
STATIC_MEMOIZER(fibonacci)::get()(n - 2);
}
完整代码
#include <vector>
#include <iostream>
#include <functional>
#include <unordered_map>
#include <map>
#include <memory>
template <typename ReturnType, typename... Args>
std::function<ReturnType (Args...)>
memoize(ReturnType (*func) (Args...))
{
auto cache = std::make_shared<std::map<std::tuple<Args...>, ReturnType>>();
return ([=](Args... args) mutable {
std::tuple<Args...> t(args...);
if (cache->find(t) == cache->end())
(*cache)[t] = func(args...);
return (*cache)[t];
});
}
enum class Cache : unsigned int { NO_RECLAIM, RECLAIM };
template <typename F_ret, typename... F_args>
std::function<F_ret (F_args...)>
memoized_recursion(F_ret (*func)(F_args...), Cache c = Cache::NO_RECLAIM)
{
typedef std::function<F_ret (F_args...)> FunctionType;
static std::unordered_map<decltype(func), FunctionType> functor_map;
if(Cache::RECLAIM == c)
return functor_map[func] = memoize(func);
if(functor_map.find(func) == functor_map.end())
functor_map[func] = memoize(func);
return functor_map[func];
}
template <typename Sig, Sig funcptr>
struct static_memoizer;
template <typename F_ret, typename... F_args, F_ret (*func)(F_args...)>
struct static_memoizer<F_ret (*)(F_args...), func>
{
static
std::function<F_ret (F_args...)> &
get(Cache c = Cache::NO_RECLAIM)
{
static std::function<F_ret (F_args...)> mfunc (memoize(func));
if(Cache::RECLAIM == c)
mfunc = memoize(func);
return mfunc;
}
};
#define STATIC_MEMOIZER(func) static_memoizer<decltype(&func), &func>::get()
unsigned long fibonacci(unsigned n)
{
return (n < 2) ? n :
memoized_recursion(fibonacci)(n - 1) +
memoized_recursion(fibonacci)(n - 2);
}
unsigned long fib(unsigned n)
{
return (n < 2) ? n :
STATIC_MEMOIZER(fib)(n - 1) +
STATIC_MEMOIZER(fib)(n - 2);
}
int main(void) {
for(int i=0;i < 300; i++)
std::cout << fibonacci(i) << " ";
std::cout << std::endl;
for(int i=0;i < 300; i++)
std::cout << fib(i) << " ";
memoized_recursion(fibonacci, Cache::RECLAIM);
}
运行结果: