c++ python @cache 的模仿


第一次看到python @cache 魔法的时候是在看 @灵茶山山神 在B站讲题, 当时的反应就是一脸震惊, 看着我c++ dfs 没有记忆化的代码陷入了沉思。

那么c++ 可不可以相对方便的像 python那样 稍微加亿点点代码实现 从暴力搜索到 记忆化搜索的 华丽的转身了,那就有了下面这个拙劣的模仿

附加代码

class null_param {
};
template<typename Sig, class F>
class memoize_helper;

template<typename R, typename... Args, class F>
class memoize_helper<R(Args...), F> {
private:
    using function_type = F;
    using args_tuple_type = tuple<Args...>;

    function_type f;
    mutable map<args_tuple_type, R> cache;

public:
    template<class Function>
    memoize_helper(Function &&f, null_param) : f(std::forward<Function>(f)) {}

    memoize_helper(const memoize_helper &other) : f(other.f) {}

    template<class ...InnerArgs>
    R operator()(InnerArgs &&... args) const {
        auto args_tuple = make_tuple(std::forward<InnerArgs>(args)...);
        auto it = cache.find(args_tuple);
        if (it != cache.end()) {
            return it->second;
        }
        return cache[args_tuple] = f(*this, std::forward<InnerArgs>(args)...);
    }
};

template<int Dim, typename R, class F>
class memoize_vec_helper;

// 一维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<1, R(Args...), F> {
private:
    using function_type = F;
    function_type f;
    mutable vector<R> cache;
    R dv;

public:
    template<class Function>
    memoize_vec_helper(Function &&f, int sz, R r, R bad) : f(std::forward<Function>(f)), cache(sz, r), dv(r) {}

    template<class InnerArgs>
    R operator()(InnerArgs && arg) const {
        if (cache[arg] != dv) {
            return cache[arg];
        }
        return cache[arg] = f(*this, std::forward<InnerArgs>(arg));
    }
};

// 二维数组的特化
template<typename R, typename ...Args, class F>
class memoize_vec_helper<2, R(Args...), F> {
private:
    using function_type = F;
    function_type f;
    mutable vector<vector<R>> cache;
    R dv;

public:
    template<class Function>
    memoize_vec_helper(Function &&f, int fs, int ss, R r) : f(std::forward<Function>(f)), cache(fs, vector<R>(ss, r)), dv(r){}

    template<typename IndexType>
    R operator()(IndexType first, IndexType second) const {
        // 这里需要修改,因为arg现在是一个pair或者tuple
        if (cache[first][second] != dv) {
            return cache[first][second];
        }
        return cache[first][second] = f(*this, first, second);
    }
};

/**
 * @brief  cache使用map
 * 
 * @tparam Sig 
 * @tparam F 
 * @param f 
 * @return memoize_helper<Sig, std::decay_t<F>> 
 */
template<class Sig, class F>
memoize_helper<Sig, std::decay_t<F>> cache(F &&f) {
    return memoize_helper<Sig, std::decay_t<F>>(std::forward<F>(f), null_param{});
}

// 创建一维数组的函数
/***
 * @brief  chache 使用一维数组
 * @param f 函数
 * @param sz 数组大小
 * @param default_value 默认值
 * @return
 */
template<class Sig, class F>
memoize_vec_helper<1, Sig, std::decay_t<F>> cache_vec(F &&f, int sz, int default_value) {
    return memoize_vec_helper<1, Sig, std::decay_t<F>>(std::forward<F>(f), sz, default_value);
}

// 创建二维数组的函数
/***
 * @brief cache 使用二维数组
 * @param f 函数
 * @param sz 第一维数组大小
 * @param sz2 第二维数组大小
 * @param default_value 默认值
 * @return
 */
template<class Sig, class F>
memoize_vec_helper<2, Sig, std::decay_t<F>> cache_vec2(F &&f, int sz, int sz2, int default_value) {
    return memoize_vec_helper<2, Sig, std::decay_t<F>>(std::forward<F>(f), sz, sz2, default_value);
}

使用方式

1、将上面的代码拷贝到代码文件
2、写出暴力dfs的代码
3、修改 , 我们会发现, 代码基本上一样的

比如暴力搜索代码如下
function<R(Args...)> dfs = [&](Args... args)->R {
    // 逻辑代码
};
调用方式 
dfs(args...);

修改如下

auto memo = [&](auto& dfs, Args... args)->R {
    // 逻辑代码, 这里的代码和之前不变
};
// 调用方式
cache<R(Args...)>(memo)(args...)

R 是函数的返回值
Args... 是参数列表的类型
args... 实际的调用参数

补充 使用数组代替map做为缓存

上面的cache 使用的是map, 但是更多时候我们可以使用一维数组 和二维数组

使用一维数组的方式

cache_vec<R(int)>(memo)(int, #数组大小, #数组默认值); 

使用二维数组的方式

cache_vec2<R(int, int)>(memo)(int, int, #数组大小, #第二维数组大小, #数组默认值); 

举例 打家劫社舍问题

function<int(int)> dfs = [&](int n) -> int {
    if(n < 0) return 0;
        return max(dfs(n - 1), dfs(n - 2) + nums[n]);
    };
dfs(nums.size() - 1);   
 

记忆化代码
auto memo = [&](auto& dfs, int n) -> int {
    if(n < 0) return 0;
    return max(dfs(n - 1), dfs(n - 2) + nums[n]);
};
return cache<int(int)>(memo)(nums.size() - 1); 


实战,请详细比较 —brute的代码和 记忆话后的代码

爬楼梯问题 70.爬楼梯

class Solution {
public:
    int climbStairs(int n) {
        auto memo = [&](auto& dfs, int i)->int {
            if (i <= 2) {
                return i;
            }
            return dfs(i - 1) + dfs(i - 2);
        };
        return cache<int(int)>(memo)(n);
    }
};

打家劫舍问题 198.打家劫舍

class Solution {
public:
    // 暴力代码
    int rob_brute(vector<int>& nums) {
        function<int(int)> dfs = [&](int n) -> int {
            if(n < 0) return 0;
            return max(dfs(n - 1), dfs(n - 2) + nums[n]);
        };
        return dfs(nums.size() - 1);   
    }

    int rob(vector<int>& nums) {
        auto memo = [&](auto& dfs, int n) -> int {
            if(n < 0) return 0;
            return max(dfs(n - 1), dfs(n - 2) + nums[n]);
        };
        return cache<int(int)>(memo)(nums.size() - 1);   
    }
};

网格问题 62. 不同路径

class Solution {
public:
    int uniquePaths(int m, int n) {
        auto memo = [&](auto& dfs, int i, int j) {
            if (i >= m || j >= n) {
                return 0;
            }
            if (i == m - 1 and j == n - 1) {
                return 1;
            }
            return dfs(i + 1, j) + dfs(i, j + 1);
        };
        return cache<int(int, int)>(memo)(0, 0);
    }
};

01背包问题 416.分割等和子集

class Solution {
public:
    bool canPartition_brute(vector<int>& nums) {
        if (nums.size() <= 1) {
            return false;
        }
        int total = accumulate(nums.begin(), nums.end(), 0);
        if ((total & 0x1) != 0) {
            return false;
        }
        ranges::sort(nums);
        int half = total >> 1;
        function<int(int, int)> dfs = [&](int i, int j)->int {
            if (j == 0) {
                return 1;
            }
            if (i >= nums.size() || j < 0 || j < nums[i]) {
                return 0;
            }
            return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
        };
        return dfs(0, half);
    }

    bool canPartition(vector<int>& nums) {
        if (nums.size() <= 1) {
            return false;
        }
        int total = accumulate(nums.begin(), nums.end(), 0);
        if ((total & 0x1) != 0) {
            return false;
        }
        ranges::sort(nums);
        int half = total >> 1;
        auto memo = [&](auto& dfs, int i, int j)->int {
            if (j == 0) {
                return 1;
            }
            if (i >= nums.size() || j < 0 || j < nums[i]) {
                return 0;
            }
            return dfs(i + 1, j) || dfs(i + 1, j - nums[i]);
        };
        return cache<int(int, int)>(memo)(0, half);
    }
};
  • 5
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值