[C/C++ Idle Practice (4)] Writing a Binary Indexed Tree with Iterators Using C++ Templates, for Prefix Sums

Preface

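Recently I had a sudden idea and thought of a curious scenario: suppose there are $N$ people, each starting with $1$ yuan. In each round, $1$ yuan is handed to one randomly chosen person, and each person's chance of receiving it is proportional to the money they currently hold (the original intent was to simulate the Matthew effect). What distribution does each person's wealth follow in the end? The simulation turned out to give an exponential distribution, which I attribute to the memorylessness of the process, but that is not the point here, so I will not dwell on it.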

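The programming difficulty lies in making each person's probability of receiving money proportional to their current holdings. A brute-force approach recomputes the prefix sums in every iteration to obtain the cumulative distribution, then draws a uniformly distributed random integer and uses lower_bound to find the index of the recipient. Clearly one such iteration costs $O(N)$, so running $N$ iterations already costs $O(N^2)$, which is hard to accept.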
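
The brute-force step described above can be sketched as follows. This is only a minimal illustration: the function name pick_brute_force and the convention that the random "ticket" lies in [1, total] are assumptions of this sketch, not part of the project.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <vector>

// Hypothetical helper: returns the index of the person who receives the money.
// `ticket` is assumed to be drawn uniformly from [1, total money].
std::size_t pick_brute_force(const std::vector<long long> &money, long long ticket) {
    // Inclusive prefix sums: prefix[i] = money[0] + ... + money[i].  O(N)
    std::vector<long long> prefix(money.size());
    std::partial_sum(money.begin(), money.end(), prefix.begin());
    assert(!prefix.empty() && ticket >= 1 && ticket <= prefix.back());

    // First position whose prefix sum is >= ticket.  O(log N)
    auto it = std::lower_bound(prefix.begin(), prefix.end(), ticket);
    return static_cast<std::size_t>(it - prefix.begin());
}
```

For money = {1, 2, 3}, the tickets 1 to 6 map to the indices 0, 1, 1, 2, 2, 2, so each index is chosen with probability proportional to its amount. The cost per draw is dominated by rebuilding the prefix sums.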

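Of course, for a fixed array the prefix sums can be precomputed once, after which any range sum is answered in $O(1)$ (see LeetCode 303, Range Sum Query - Immutable), followed by an $O(\log N)$ lower_bound. In fact the standard library's std::discrete_distribution class template is typically implemented along these lines. The drawback is that it cannot be adjusted dynamically: every change to the weights requires recomputing the prefix sums.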
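
As a small illustration of the fixed-array case, the sketch below precomputes the prefix sums once and answers range sums in constant time. The class name StaticPrefixSum is hypothetical and chosen only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical class name; prefix[i] holds the sum of nums[0..i).
class StaticPrefixSum {
    std::vector<long long> prefix;
public:
    explicit StaticPrefixSum(const std::vector<long long> &nums)
            : prefix(nums.size() + 1, 0) {
        for (std::size_t i = 0; i < nums.size(); ++i)
            prefix[i + 1] = prefix[i] + nums[i];   // built once in O(N)
    }
    // Sum of the half-open interval [left, right) in O(1).
    long long sum(std::size_t left, std::size_t right) const {
        assert(left <= right && right < prefix.size());
        return prefix[right] - prefix[left];
    }
};
```

Changing a single element of the underlying array invalidates every later entry of prefix, which is exactly why this layout does not suit the simulation.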

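When it comes to prefix sums that can be modified dynamically, the famous candidates are the segment tree and the binary indexed tree (see LeetCode 307, Range Sum Query - Mutable). A binary indexed tree, for example, supports both point updates and prefix-sum queries in $O(\log N)$ time. So, since I had time on my hands anyway, I used it for practice, which also serves as rehabilitation exercise for my C++.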

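Of course, if it is worth doing, it is worth doing well and thoroughly. This project implements a fairly complete binary indexed tree as a class template, together with a random-access iterator over the prefix sums, written to STL conventions so that it can interact with STL algorithms. For example, to find the first element in the range [first, last) that is not less than val, call std::lower_bound(bt.begin(), bt.end(), val). It performs $O(\log N)$ dereferences, each of which is an $O(\log N)$ prefix-sum query, so the element is found in $O(\log^2 N)$ time.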
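
To connect this back to the money simulation, here is a minimal sketch of one weighted draw on top of a binary indexed tree. It is not the project's BITree class: MiniBIT and pick are hypothetical names, the element type is fixed to long long, and the binary search is written by hand to mirror what std::lower_bound does over the prefix sums.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal binary indexed tree with a 0-based public interface (hypothetical name).
struct MiniBIT {
    std::vector<long long> tree;                       // 1-based internally
    explicit MiniBIT(std::size_t n) : tree(n + 1, 0) {}
    void add(std::size_t index, long long val) {       // A[index] += val, O(log N)
        for (++index; index < tree.size(); index += index & (0 - index))
            tree[index] += val;
    }
    long long sum(std::size_t index) const {           // sum of A[0..index), O(log N)
        long long ans = 0;
        for (; index > 0; index &= index - 1)
            ans += tree[index];
        return ans;
    }
};

// Index i with sum(i) < ticket <= sum(i + 1); ticket is assumed to lie in [1, total].
// O(log N) search steps, each calling sum(), hence O(log^2 N) overall.
std::size_t pick(const MiniBIT &bit, long long ticket) {
    std::size_t n = bit.tree.size() - 1;
    assert(n > 0 && ticket >= 1 && ticket <= bit.sum(n));
    std::size_t lo = 0, hi = n - 1;
    while (lo < hi) {
        std::size_t mid = lo + (hi - lo) / 2;
        if (bit.sum(mid + 1) < ticket) lo = mid + 1;
        else hi = mid;
    }
    return lo;
}
```

One simulation round would then draw a ticket uniformly from [1, bit.sum(n)], call pick, and call add on the returned index with 1, all without rebuilding anything.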

The project is mirrored on GitHub, with detailed bilingual (Chinese and English) comments and a README. The link is:

https://github.com/LiuZJ2019/BinaryIndexedTree_TemplateClass_And_Iterator

The project's functionality is roughly as shown in the following demo code:

template<class T>
class BITree {
// Data members
private:
    std::vector<T> m_tree;     // Binary Indexed Tree
    std::vector<T> m_arr;
    ...
    
// Public interface
public:
    // fun1: random access
    virtual T get(size_t index) const;
    // fun2: increment an element
    virtual void add(size_t index, T val);
    // fun3: overwrite an element
    virtual void update(size_t index, T val);
    // fun4: sum[0:index]
    virtual T sum(size_t index) const;
    // fun5: sum[left:right]
    virtual T sum(size_t left, size_t right) const;
    // fun6: length of m_arr
    virtual size_t size() const;
    // fun7: resize
    virtual void resize(size_t siz);
    // fun8: begin iterator, compatible with the STL
    virtual iterator begin();
    // fun9: end iterator, compatible with the STL
    virtual iterator end();
    // fun10: access m_arr
    virtual const std::vector<T> &get_arr() const;
    // fun11: print
    virtual void print(std::ostream &os) const;

// Random-access iterator
    struct sum_iterator : public std::iterator<...> {
        ...
    };
};

// Wrapper for stream-style IO
template<class T>
std::ostream& operator<< (std::ostream& os, const BITree<T> &packet);

Full code

The core code is in bitree.h and bitree.cpp; bitree_test.h and main.cpp are for testing. bitree_test.h gives many examples showing how to use the code, and the project's code is commented in detail.

bitree.h

Too lazy to rework it, so the code is pasted directly:

/**
 * @file    bitree.h
 * @brief   Binary Indexed Tree
 * @author  间宫羽咲sama
 * @note    Template implementation of Binary Indexed Tree
 */

#ifndef INC_BINARYINDEXEDTREE_BITREE_H
#define INC_BINARYINDEXEDTREE_BITREE_H

#include <cstddef>      // size_t
#include <iostream>
#include <iterator>     // std::iterator, std::random_access_iterator_tag
#include <vector>

/**
 * @brief   Binary Indexed Tree
 * @details The English notes are as follows:
 *          Suppose 'L(x)' represents the lowest set bit of 'x', and '(101100)' denotes a binary representation
 *          For example, if 'x == 44 == (101100)', then 'L(x) == 4 == (100)'
 *          Suppose 'T(i)' represents the sum of A over the interval ( i - L(i), i ]
 *          For example, if 'i == (101100)', then 'i - L(i) == (101000)', and 'T(i)' is $\sum_{k=i-L(i)+1}^{i} {A_k}$
 *          In this case, the update operation and the sum operation both have complexity $O(log N)$
 *
 *       1. Sum: O(log N)
 *          Notice that: (0, (101100)] = ( (0), (100000) ] + ( (100000), (101000) ] + ( (101000), (101100) ]
 *          The maximum number of operations cannot exceed the number of bits, i.e. O(log N)
 *       2. Update: O(log N)
 *          If you want to modify index (101101), you only need to modify all the intervals containing it,
 *          the following numbers correspond to ( i - L(i), i ] will contain (101101):
 *              (101101), (101110), (110000), (1000000), (10000000), (100000000) ...
 *          But this number cannot exceed the length of the array, at most O(log N)
 *       3. Get: O(1)
 *          copy raw array to make the complexity of the get operation become O(1)
 */
template<class T>
class BITree {
public:
    // 1. This iterator can interact with algorithm 'lower_bound' in STL
    //    e.g. If you want to find the first element in the range [first, last) which does not compare less than 'val'
    //    You can use std::lower_bound(bt.begin(), bt.end(), val);
    class sum_iterator;
    using iterator = BITree<T>::sum_iterator;
private:
    // 2. Data member
    std::vector<T> m_tree;     // Binary Indexed Tree
    std::vector<T> m_arr;      // Let the get complexity become O(1), at the cost of doubling the space

protected:
    // 3. Only for the faster implementation of the constructor, it will not be called at other times
    virtual void add_impl(size_t index, T val);

public:
    // 4. Constructor
    BITree() = default;
    ~BITree() = default;
    BITree(const BITree &bt);
    BITree(BITree &&bt);
    BITree(const std::vector<T> &nums);
    BITree(std::vector<T> &&nums);


    // 5. All public interfaces: get, add, update, sum, resize
    /**
     * @brief   return A[index]
     * @note    plz ensure 0 <= index < m_arr.size()
     */
    virtual T get(size_t index) const;

    /**
     * @brief   Let A[index] increment by val
     * @note    plz ensure 0 <= index < m_arr.size()
     */
    virtual void add(size_t index, T val);

    /**
     * @brief   Let A[index] update by val, i.e. increment by (val - now_value)
     * @note    plz ensure 0 <= index < m_arr.size()
     */
    virtual void update(size_t index, T val);

    /**
     * @brief   Get the sum of the interval [0, index)
     * @note    plz ensure 0 <= index < m_arr.size() + 1
     */
    virtual T sum(size_t index) const;

    /**
     * @brief   Get the sum of the interval [left, right)
     *          if left > right, return the opposite of the sum of the interval [right, left)
     * @note    plz ensure 0 <= left, right < m_arr.size() + 1
     */
    virtual T sum(size_t left, size_t right) const;

    /**
     * @brief   return m_arr.size()
     */
    virtual size_t size() const;

    /**
     * @brief   resize, if siz < now_size, keep only [0, siz)
     */
    virtual void resize(size_t siz);

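    /**
     * @brief   return an iterator to the first prefix sum, i.e. sum(0)
     */
    virtual iterator begin();

    /**
     * @brief   return an iterator at index size(), one past the range [begin, end);
     *          it is still dereferenceable and yields the total sum
     */
    virtual iterator end();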


    // 6. Make it easier to use
    /**
     * @brief   Get raw 'm_arr' for debug or other use
     */
    virtual const std::vector<T> &get_arr() const;


    // 7. For operator<< without the 'friend' keyword
    virtual void print(std::ostream &os) const;


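    // 8. Iterator for searching the prefix sums, e.g. to interact with lower_bound
    //    Note: std::iterator is deprecated since C++17; the member aliases below already provide the same types.
    //    Note: difference_type is unsigned here, whereas the standard expects a signed type such as std::ptrdiff_t.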
    struct sum_iterator : public std::iterator<
            std::random_access_iterator_tag,    // iterator_category
            T,                                  // value_type
            size_t,                             // difference_type
            size_t,                             // pointer
            T>                                  // reference
    {
    private:
        BITree<T> *m_bt;
        size_t m_idx;

        using Self = sum_iterator;

    public:
        using iterator_category = std::random_access_iterator_tag;
        using value_type = T;
        using difference_type = size_t;
        using pointer = size_t;
        using reference = T;
        
    public:
        pointer get_idx() const {return m_idx;}

    public:
        sum_iterator(BITree<T> *bt, size_t idx) : m_bt(bt), m_idx(idx) {}
        sum_iterator(const Self &other) : m_bt(other.m_bt), m_idx(other.m_idx) {}
        sum_iterator(Self &&other) : m_bt(other.m_bt), m_idx(other.m_idx) {}
        sum_iterator& operator=(const Self &other) {m_bt = other.m_bt; m_idx = other.m_idx; return *this;}
        sum_iterator& operator=(Self &&other) {m_bt = other.m_bt; m_idx = other.m_idx; return *this;}
        reference operator*() const {return m_bt->sum(m_idx);}
        Self& operator++() {++ m_idx; return *this;}
        Self  operator++(int) {Self ret_val = *this; ++ (*this); return ret_val;}
        Self& operator+=(difference_type n) {m_idx += n; return *this;}
        Self  operator+(difference_type n) {return Self(m_bt, m_idx + n);}
        Self& operator--() {-- m_idx; return *this;}
        Self  operator--(int) {Self ret_val = *this; -- (*this); return ret_val;}
        Self& operator-=(difference_type n) {m_idx -= n; return *this;}
        Self  operator-(difference_type n) {return Self(m_bt, m_idx - n);}
        bool operator==(Self other) const {return m_idx == other.m_idx;}
        bool operator!=(Self other) const {return !(*this == other);}
        bool operator<(Self other)  const {return m_idx < other.m_idx;}
        bool operator<=(Self other) const {return m_idx <= other.m_idx;}
        bool operator>(Self other)  const {return m_idx > other.m_idx;}
        bool operator>=(Self other) const {return m_idx >= other.m_idx;}
        difference_type operator-(Self other) const {return m_idx - other.m_idx;}
    };
};

template<class T>
std::ostream& operator<< (std::ostream& os, const BITree<T> &packet);



// a class template cannot be split across '.h' and '.cpp', so the implementation is included here
#include "bitree.cpp"

#endif //INC_BINARYINDEXEDTREE_BITREE_H
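
The two chains described in the header comment of bitree.h can be checked with a few lines of code. The sketch below is independent of the project; the helper names lowbit, query_chain and update_chain are hypothetical, and indices are 1-based as in the comment.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// L(x): the lowest set bit of x (x > 0).
std::size_t lowbit(std::size_t x) { return x & (0 - x); }

// Tree nodes read by a prefix-sum query that ends at 1-based index i.
std::vector<std::size_t> query_chain(std::size_t i) {
    std::vector<std::size_t> chain;
    for (; i > 0; i -= lowbit(i))
        chain.push_back(i);
    return chain;
}

// Tree nodes written when 1-based index i is updated in a tree of n elements.
std::vector<std::size_t> update_chain(std::size_t i, std::size_t n) {
    assert(i >= 1);   // lowbit(0) == 0 would never advance
    std::vector<std::size_t> chain;
    for (; i <= n; i += lowbit(i))
        chain.push_back(i);
    return chain;
}
```

With these helpers, query_chain(44) yields 44, 40, 32, matching the split of (0, (101100)] into three intervals, and update_chain(45, 200) yields 45, 46, 48, 64, 128, matching the list (101101), (101110), (110000), (1000000), (10000000).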

bitree.cpp

Too lazy to rework it, so the code is pasted directly:

/**
 * @file    bitree.cpp
 * @brief   Binary Indexed Tree
 * @author  间宫羽咲sama
 * @note    Template implementation of Binary Indexed Tree
 */

#ifndef INC_BINARYINDEXEDTREE_BITREE_CPP
#define INC_BINARYINDEXEDTREE_BITREE_CPP

#include "bitree.h"
#include <utility>      // std::move

template<class T>
BITree<T>::BITree(const BITree<T> &nums)
        : m_tree(nums.m_tree)
        , m_arr(nums.m_arr)
{
}

template<class T>
BITree<T>::BITree(BITree<T> &&nums)
        : m_tree(std::move(nums.m_tree))
        , m_arr(std::move(nums.m_arr))
{
}

template<class T>
BITree<T>::BITree(const std::vector<T> &nums)
        : m_tree(nums.size() + 1)
        , m_arr(nums)
{
    for (size_t i = 0; i < m_arr.size(); ++ i)
        add_impl(i, m_arr[i]);
}

template<class T>
BITree<T>::BITree(std::vector<T> &&nums)
        : m_tree(nums.size() + 1)
        , m_arr(std::move(nums))
{
    for (size_t i = 0; i < m_arr.size(); ++ i)
        add_impl(i, m_arr[i]);
}

template<class T> T
BITree<T>::get(size_t index) const {
    return m_arr[index];
}

template<class T> void
BITree<T>::add_impl(size_t index, T val) {
    for (++ index; index < m_tree.size(); index += index & (-index))
        m_tree[index] += val;
}

template<class T> void
BITree<T>::add(size_t index, T val) {
    m_arr[index] += val;
    add_impl(index, val);
}

template<class T> void
BITree<T>::update(size_t index, T val) {
    add(index, val - get(index));
}

template<class T> T
BITree<T>::sum(size_t index) const {
    T ans = 0;
    for (; index > 0; index &= index - 1)
        ans += m_tree[index];
    return ans;
}

template<class T> T
BITree<T>::sum(size_t left, size_t right) const {
    return sum(right) - sum(left);
}

template<class T> const std::vector<T> &
BITree<T>::get_arr() const {
    return m_arr;
}

template<class T> size_t
BITree<T>::size() const {
    return m_arr.size();
}

template<class T> void
BITree<T>::resize(size_t siz) {
    if (siz == m_arr.size())
        return;

    m_arr.resize(siz);
    // erase all
    m_tree.resize(0);
    m_tree.resize(siz + 1);
    for (size_t i = 0; i < m_arr.size(); ++ i)
        add_impl(i, m_arr[i]);
}

template<class T> typename BITree<T>::iterator
BITree<T>::begin() {
    return BITree<T>::iterator(this, 0);
}

template<class T> typename BITree<T>::iterator
BITree<T>::end() {
    return BITree<T>::iterator(this, size());
}

template<class T> void
BITree<T>::print(std::ostream &os) const {
    os << "[";
    // write the separator before every element except the first;
    // unlike 'm_arr.end() - 1', this stays well-defined when m_arr is empty
    for (size_t i = 0; i < m_arr.size(); ++ i)
        os << (i != 0 ? "," : "") << m_arr[i];
    os << "]";
}

template<class T> std::ostream &
operator<< (std::ostream &os, const BITree<T> &packet) {
    packet.print(os);
    return os;
}

#endif //INC_BINARYINDEXEDTREE_BITREE_CPP
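
The constructors above call add_impl once per element, which costs $O(N \log N)$ in total. As a side note, a binary indexed tree can also be built in $O(N)$ by letting every node push its finished block sum to the next node that covers it. The sketch below shows that alternative under stated assumptions: build_linear and prefix_sum are hypothetical free functions, not part of the project, and the element type is fixed to long long.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Builds the 1-based tree array for `nums` in O(N) (hypothetical free function).
std::vector<long long> build_linear(const std::vector<long long> &nums) {
    std::vector<long long> tree(nums.size() + 1, 0);
    for (std::size_t i = 1; i < tree.size(); ++i) {
        tree[i] += nums[i - 1];                    // the node's own element
        std::size_t parent = i + (i & (0 - i));    // next node whose interval covers i
        if (parent < tree.size())
            tree[parent] += tree[i];               // push the finished block upward
    }
    return tree;
}

// Prefix sum over [0, index), using the same loop as BITree::sum.
long long prefix_sum(const std::vector<long long> &tree, std::size_t index) {
    assert(index < tree.size());
    long long ans = 0;
    for (; index > 0; index &= index - 1)
        ans += tree[index];
    return ans;
}
```

Each node is visited once and does constant work, so the build is linear. The result is the same array that repeated add_impl calls would produce.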

bitree_test.h

Too lazy to rework it, so the code is pasted directly:

#ifndef BINARYINDEXEDTREE_BITREE_TEST_H
#define BINARYINDEXEDTREE_BITREE_TEST_H

#include "bitree.h"
#include <algorithm>

void BITree_test1() {
    std::cout << std::endl << "---------- BITree_test1 begin ----------" << std::endl;

    std::cout << std::endl << "---------- Part-1: Construct ----------" << std::endl;
    std::vector<double> vec1{2.71828, 3.14159, 1.14514, 1.19198, 0.57721};
    size_t siz = vec1.size();
    BITree<double> bt1(std::move(vec1));

    // expect output: bt1 = [2.71828,3.14159,1.14514,1.19198,0.57721]
    std::cout << "bt1 = " << bt1 << std::endl;
    // expect output: vec1.size() = 0
    std::cout << "vec1.size() = " << vec1.size() << std::endl;

    BITree<double> bt2(bt1);

    // expect output: bt2 = [2.71828,3.14159,1.14514,1.19198,0.57721]
    std::cout << "bt2 = " << bt2 << std::endl;

    BITree<double> bt3(std::move(bt2));

    // expect output: bt2 = []
    std::cout << "bt2 = " << bt2 << std::endl;
    // expect output: bt3 = [2.71828,3.14159,1.14514,1.19198,0.57721]
    std::cout << "bt3 = " << bt3 << std::endl;

    std::cout << std::endl << "---------- Part-2: Sum ----------" << std::endl;

    // expect output: bt1.sum(0) = 0
    //                bt1.sum(1) = 2.71828
    //                bt1.sum(2) = 5.85987
    //                bt1.sum(3) = 7.00501
    //                bt1.sum(4) = 8.19699
    //                bt1.sum(5) = 8.7742
    for (size_t i = 0; i <= siz; ++ i)
        std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;

    std::cout << std::endl << "---------- Part-3: Add ----------" << std::endl;

    bt1.add(2, 1);

    // expect output: bt1 = [2.71828,3.14159,2.14514,1.19198,0.57721]
    std::cout << "bt1 = " << bt1 << std::endl;

    // expect output: bt1.sum(0) = 0
    //                bt1.sum(1) = 2.71828
    //                bt1.sum(2) = 5.85987
    //                bt1.sum(3) = 8.00501
    //                bt1.sum(4) = 9.19699
    //                bt1.sum(5) = 9.7742
    for (size_t i = 0; i <= siz; ++ i)
        std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;

    std::cout << std::endl << "---------- Part-4: ReSize ----------" << std::endl;

    bt1.resize(siz - 1);

    // expect output: bt1 = [2.71828,3.14159,2.14514,1.19198]
    std::cout << "bt1 = " << bt1 << std::endl;

    // expect output: bt1.sum(0) = 0
    //                bt1.sum(1) = 2.71828
    //                bt1.sum(2) = 5.85987
    //                bt1.sum(3) = 8.00501
    //                bt1.sum(4) = 9.19699
    for (size_t i = 0; i <= siz - 1; ++ i)
        std::cout << "bt1.sum(" << i << ") = " << bt1.sum(i) << std::endl;

    bt3.resize(siz + 1);

    // expect output: bt3 = [2.71828,3.14159,1.14514,1.19198,0.57721,0]
    std::cout << "bt3 = " << bt3 << std::endl;

    // expect output: bt3.sum(0) = 0
    //                bt3.sum(1) = 2.71828
    //                bt3.sum(2) = 5.85987
    //                bt3.sum(3) = 7.00501
    //                bt3.sum(4) = 8.19699
    //                bt3.sum(5) = 8.7742
    //                bt3.sum(6) = 8.7742
    for (size_t i = 0; i <= siz + 1; ++ i)
        std::cout << "bt3.sum(" << i << ") = " << bt3.sum(i) << std::endl;

    std::cout << std::endl << "---------- Part-5: Iterator ----------" << std::endl;

    BITree<double>::iterator it3 = std::lower_bound(bt1.begin(), bt1.end(), 5);
    // expect output: std::lower_bound(bt1.begin(), bt1.end(), 5) = 5.85987
    std::cout << "std::lower_bound(bt1.begin(), bt1.end(), 5) = " << *it3 << std::endl;

    ++ it3;
    // expect output: *++it = 8.00501
    std::cout << "*++it = " << *it3 << std::endl;

    it3 = std::lower_bound(bt1.begin(), bt1.end(), 5.85988);
    // expect output: std::lower_bound(bt1.begin(), bt1.end(), 5.85988) = 8.00501
    std::cout << "std::lower_bound(bt1.begin(), bt1.end(), 5.85988) = " << *it3 << std::endl;

    // expect output: bt1.sum(0) = 0
    //                bt1.sum(1) = 2.71828
    //                bt1.sum(2) = 5.85987
    //                bt1.sum(3) = 8.00501
    //                bt1.sum(4) = 9.19699
    // end() is included on purpose: dereferencing it yields bt1.sum(bt1.size())
    for (auto it = bt1.begin(); it <= bt1.end(); ++ it)
        std::cout << "bt1.sum(" << it.get_idx() << ") = " << *it << std::endl;

    // expect output: bt1.sum... = 0
    //                bt1.sum... = 2.71828
    //                bt1.sum... = 5.85987
    //                bt1.sum... = 8.00501
    for (auto prefix_sum : bt1)
        std::cout << "bt1.sum... = " << prefix_sum << std::endl;

}

#endif //BINARYINDEXEDTREE_BITREE_TEST_H

main.cpp

Too lazy to rework it, so the code is pasted directly:

#include <iostream>
#include "bitree.h"
#include "bitree_test.h"

int main() {
    BITree_test1();
    return 0;
}

Summary

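Overall, the code itself is secondary. The main benefit was the chance to practice C++ coding conventions, such as how to write copy, move, and default constructors, and how to write an iterator that follows STL conventions.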
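
For readers on C++17 or later, where std::iterator is deprecated, the same kind of iterator can be written with plain member aliases. The following is a minimal sketch rather than the project's iterator: PrefixSource and SumIter are hypothetical names, the source computes sums naively, and only the operations that std::lower_bound needs are provided. Strictly speaking, an iterator that returns values instead of references does not meet the classic forward-iterator requirements, although common standard library implementations accept it here.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <iterator>
#include <vector>

// Hypothetical stand-in for the tree: anything exposing sum(k) would do.
struct PrefixSource {
    std::vector<long long> data;
    long long sum(std::size_t k) const {          // sum of data[0..k)
        assert(k <= data.size());
        long long s = 0;
        for (std::size_t i = 0; i < k; ++i) s += data[i];
        return s;
    }
};

class SumIter {
    const PrefixSource *src;
    std::ptrdiff_t idx;
public:
    // Member types read by std::iterator_traits; no std::iterator base is needed.
    using iterator_category = std::random_access_iterator_tag;
    using value_type        = long long;
    using difference_type   = std::ptrdiff_t;     // signed, as the standard expects
    using pointer           = const long long *;
    using reference         = long long;          // sums are computed, so returned by value

    SumIter(const PrefixSource *s, std::ptrdiff_t i) : src(s), idx(i) {}
    reference operator*() const { return src->sum(static_cast<std::size_t>(idx)); }
    SumIter &operator++() { ++idx; return *this; }
    SumIter &operator--() { --idx; return *this; }
    SumIter &operator+=(difference_type n) { idx += n; return *this; }
    friend SumIter operator+(SumIter it, difference_type n) { return it += n; }
    friend difference_type operator-(const SumIter &a, const SumIter &b) { return a.idx - b.idx; }
    friend bool operator==(const SumIter &a, const SumIter &b) { return a.idx == b.idx; }
    friend bool operator!=(const SumIter &a, const SumIter &b) { return !(a == b); }
    friend bool operator<(const SumIter &a, const SumIter &b) { return a.idx < b.idx; }
};
```

With data = {1, 2, 3}, the range from index 0 to index 3 exposes the prefix sums 0, 1, 3, and std::lower_bound over it with the value 2 stops at index 2. The copy and move operations are left to the compiler, which is sufficient for a class holding only a pointer and an index.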

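Along the way I also hit some well-known pitfalls. For example, a template cannot simply be split into a .h file and a .cpp file, because a template's definition must be visible wherever it is instantiated; the workaround used here is to include the corresponding .cpp file at the end of the .h file. Writing the iterator was also a headache, but after doing it once it became much clearer. In short, the project itself is nothing special; its value was in getting hands-on practice with these things.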
