第一次看到python @cache 魔法的时候是在看 @灵茶山山神 在B站讲题, 当时的反应就是一脸震惊, 看着我c++ dfs 没有记忆化的代码陷入了沉思。
那么c++ 可不可以相对方便的像 python那样 稍微加亿点点代码实现 从暴力搜索到 记忆化搜索的 华丽的转身了,那就有了下面这个拙劣的模仿
附加代码
class null_param {
};
template<typename Sig, class F>
class memoize_helper;
template<typename R, typename... Args, class F>
class memoize_helper<R(Args...), F> {
private:
using function_type = F;
using args_tuple_type = tuple<Args...>;
function_type f;
mutable map<args_tuple_type, R> cache;
public:
template<class Function>
memoize_helper(Function &&f, null_param) : f(std::forward<Function>(f)) {}
memoize_helper(const memoize_helper &other) : f(other.f) {}
template<class ...InnerArgs>
R operator()(InnerArgs &&... args) const {
auto args_tuple = make_tuple(std::forward<InnerArgs>(args)...);
auto it = cache.find(args_tuple);
if (it != cache.end()) {
return it->second;
}
return cache[args_tuple] = f(*this, std::forward<InnerArgs>(args)...);
}
};
template<int Dim, typename R, class F>
class memoize_vec_helper;
// 一维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<1, R(Args...), F> {
private:
using function_type = F;
function_type f;
mutable vector<R> cache;
R dv;
public:
template<class Function>
memoize_vec_helper(Function &&f, int sz, R r, R bad) : f(std::forward<Function>(f)), cache(sz, r), dv(r) {}
template<class InnerArgs>
R operator()(InnerArgs && arg) const {
if (cache[arg] != dv) {
return cache[arg];
}
return cache[arg] = f(*this, std::forward<InnerArgs>(arg));
}
};
// 二维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<2, R(Args...), F> {
private:
using function_type = F;
function_type f;
mutable vector<vector<R>> cache;
R dv;
public:
template<class Function>
memoize_vec_helper(Function &&f, int fs, int ss, R r) : f(std::forward<Function>(f)), cache(fs, vector<R>(ss, r)), dv(r){}
template<typename IndexType>
R operator()(IndexType first, IndexType second) const {
// 这里需要修改,因为arg现在是一个pair或者tuple
if (cache[first][second] != dv) {
return cache[first][second];
}
return cache[first][second] = f(*this, first, second);
}
};
/**
* @brief cache使用map
*
* @tparam Sig
* @tparam F
* @param f
* @return memoize_helper<Sig, std::decay_t<F>>
*/
template<class Sig, class F>
memoize_helper<Sig, std::decay_t<F>> cache(F &&f) {
return memoize_helper<Sig, std::decay_t<F>>(std::forward<F>(f), null_param{});
}
// 创建一维数组的函数
/***
* @brief chache 使用一维数组
* @param f 函数
* @param sz 数组大小
* @param default_value 默认值
* @return
*/
template<class Sig, class F>
memoize_vec_helper<1, Sig, std::decay_t<F>> cache_vec(F &&f, int sz, int default_value) {
return memoize_vec_helper<1, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}
// 创建二维数组的函数
/***
* @brief cache 使用二维数组
* @param f 函数
* @param sz 第一维数组大小
* @param sz2 第二维数组大小
* @param default_value 默认值
* @return
*/
template<class Sig, class F>
memoize_vec_helper<2, Sig, std::decay_t<F>> cache_vec2(F &&f, int sz, int sz2, int default_value) {
return memoize_vec_helper<2, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, default_value);
}
使用方式
1、将上面的代码拷贝到代码文件
2、写出暴力dfs的代码
3、修改 , 我们会发现, 代码基本上一样的
比如暴力搜索代码如下
function<R(Args...)> dfs = [&](Args... args)->R {
// 逻辑代码
};
调用方式
dfs(args...);
修改如下
auto memo = [&](auto& dfs, Args... args)->R {
// 逻辑代码, 这里的代码和之前不变
};
// 调用方式
cache<R(Args...)>(memo)(args...)
R 是函数的返回值
Args... 是参数列表的类型
args... 实际的调用参数
补充 使用数组代替map做为缓存
上面的cache 使用的是map, 但是更多时候我们可以使用一维数组 和二维数组
使用一维数组的方式
cache_vec<R(int)>(memo)(int, #数组大小, #数组默认值);
使用二维数组的方式
cache_vec2<R(int, int)>(memo)(int, int, #数组大小, #第二维数组大小, #数组默认值);
举例 打家劫社舍问题
function<int(int)> dfs = [&](int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
dfs(nums.size() - 1);
记忆化代码
auto memo = [&](auto& dfs, int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return cache<int(int)>(memo)(nums.size() - 1);
实战,请详细比较 —brute的代码和 记忆话后的代码
爬楼梯问题 70.爬楼梯
class Solution {
public:
int climbStairs(int n) {
auto memo = [&](auto& dfs, int i)->int {
if (i <= 2) {
return i;
}
return dfs(i - 1) + dfs(i - 2);
};
return cache<int(int)>(memo)(n);
}
};
打家劫舍问题 198.打家劫舍
class Solution {
public:
// 暴力代码
int rob_brute(vector<int>& nums) {
function<int(int)> dfs = [&](int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return dfs(nums.size() - 1);
}
int rob(vector<int>& nums) {
auto memo = [&](auto& dfs, int n) -> int {
if(n < 0) return 0;
return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return cache<int(int)>(memo)(nums.size() - 1);
}
};
网格问题 62. 不同路径
class Solution {
public:
int uniquePaths(int m, int n) {
auto memo = [&](auto& dfs, int i, int j) {
if (i >= m || j >= n) {
return 0;
}
if (i == m - 1 and j == n - 1) {
return 1;
}
return dfs(i + 1, j) + dfs(i, j + 1);
};
return cache<int(int, int)>(memo)(0, 0);
}
};
01背包问题 416.分割等和子集
class Solution {
public:
bool canPartition_brute(vector<int>& nums) {
if (nums.size() <= 1) {
return false;
}
int total = accumulate(nums.begin(), nums.end(), 0);
if ((total & 0x1) != 0) {
return false;
}
ranges::sort(nums);
int half = total >> 1;
function<int(int, int)> dfs = [&](int i, int j)->int {
if (j == 0) {
return 1;
}
if (i >= nums.size() || j < 0 || j < nums[i]) {
return 0;
}
return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
};
return dfs(0, half);
}
bool canPartition(vector<int>& nums) {
if (nums.size() <= 1) {
return false;
}
int total = accumulate(nums.begin(), nums.end(), 0);
if ((total & 0x1) != 0) {
return false;
}
ranges::sort(nums);
int half = total >> 1;
auto memo = [&](auto& dfs, int i, int j)->int {
if (j == 0) {
return 1;
}
if (i >= nums.size() || j < 0 || j < nums[i]) {
return 0;
}
return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
};
return cache<int(int, int)>(memo)(0, half);
}
};