通用线段树的写法

class NumArray {
public interface Merger<E>
{
    E merge(E a, E b);
}
public class SegmentTree<E>
{
    /*
    使用一个数组表示区间.
    首先,用户可能要获取区间内某一个的元素,或者获取区间的某一个属性,我们在线段树中创建数组,作为区间数组的的副本,用于给出区间数组的某些属性;
    其次,我们想将data数组内的元素(arr数组区间传递进来的)组织成为一个线段树,那么还需要创建一个tree,大小为 4size
    */

    private E[] data;
    //tree结点虽然表示的是一个区间,但是它里面存储的实际上是区间内数据按一定方式计算得到的值。
    // 存储区间只是一种方式,目的还是得到区间内数字通过计算得到的值!
    private E[] tree;
    private Merger<E> merger;

    //初始化的时候传入一个区间数组
    public SegmentTree(E[] arr , Merger<E> merger)
    {
        this.merger = merger;

        data = (E[])new Object[arr.length];//使用arr初始化data
        for (int i = 0; i < arr.length ; i++)
        {
            data[i] = arr[i];
        }

        tree = (E[])new Object[arr.length * 4];

        //调用 buildSegmentTree() 方法,完成线段树数组tree每一个结点 data数组区间的确定!
        //这样线段树一初始化,它的线段树数组中每一个结点所表示的 data数组的区间也确定了!
        buildSegmentTree(0 , 0 , data.length-1);//线段树从0位置开始赋予区间,此时区间为data最大区间:0-data.length-1
    }

    //线段树所关注的区间有多少个元素
    public int getSize()
    {
        return data.length;
    }

    public E get(int index)
    {
        if(index < 0 || index >= data.length)
            throw new IllegalArgumentException("Index is illegal.");
        return data[index];//数据可以获取区间数组的某一个下标的数据
    }

//---------------------------辅助函数:获取tree数组中某一下标的结点的左右孩子结点在数组的下标
    // 返回满二叉树的数组表示中,一个索引所表示的元素的左孩子节点的索引
    private int leftChild(int index){
    return 2*index + 1;
}

    // 返回满二叉树的数组表示中,一个索引所表示的元素的右孩子节点的索引
    private int rightChild(int index){
        return 2*index + 2;
    }

//--------------------------构建
    /**
     在treeIndex的位置创建表示区间[l...r]的线段树结点,即确定 线段树数组tree 在下标 treeIndex 位置所表示的 在data数组 中的区间。

     三个元素:1-要创建的线段树结点在tree数组中的索引;2,3-当前线段树结点表示的 data数组区间的左右端点的索引,注意这个索引是在data数组中的索引
     即在tree数组中,下标为 treeIndex,它在 data数组中表示的区间范围是: l-r

     这个方法设置为私有的,并且我们会从线段树数组 tree的0位置开始,不断确定线段树数组的每一个位置所表示的 data数组的区间,
     直到data数组区间的长度为1,此时遍历到线段树的叶子结点。这个过程中线段树数组可能有很多结点没有表示的区间!

     我们在线段树类的构造方法中就调用 buildSegmentTree() 方法,完成线段树数组每一个结点 data数组区间的确定!
     */
    private void buildSegmentTree(int treeIndex, int l, int r)
    {
        if(l == r)
        {
            //此时遍历到线段树的叶子结点,因为只有叶子结点表示的 data数组的区间长度为1,那么直接给叶子结点赋值为 data[l]
            tree[treeIndex] = data[l];
            //注意,到达叶子结点就return结束函数,否则下面继续计算,会导致出现大雨tree数组下标的子结点,出现数组下标越界
            return;
        }

        //下面我们继续给当前结点的左右孩子结点设置他们在 data数组中的区间

        //首先求左右孩子结点在tree数组中的下标
        int leftNodeIndex = leftChild(treeIndex);
        int rightNodeIndex = rightChild(treeIndex);

        //求区间中间的分割值
        // int mid = (l + r) / 2;//当l与r很大的时候,可能产生整形溢出的问题
        int mid = l + (r - l) / 2;

        //递归调用buildSegmentTree设置左右孩子结点在 data数组的区间
        buildSegmentTree(leftNodeIndex , l , mid);
        buildSegmentTree(rightNodeIndex , mid+1 , r);//这种设置方法,在区间长度为奇数的时候,左孩子结点区间比右孩子结点区间长度大1

        /**
         下面考虑 线段树tree treeIndex位置的值,这个值与我们的业务相关。
         我们综合当前结点左右孩子结点在data数组中的线段信息,就可以得到当前结点的线段信息!因为当前结点线段的区间,等于它左右孩子线段区间的和!

         我们创建一个融合接口,然后用户可以创建实现这个融合接口的类,自定义它融合左右孩子区间的含义,将这个自定义的对象传递进来,
         那么此时我们tree数组中 treeIndex 位置的值,就可以根据融合的规则,融合左右孩子结点的值得到!
         public interface Merger<E>
         {
            E merge(E a , E b);
         }
         */
        tree[treeIndex] = merger.merge(tree[leftNodeIndex], tree[rightNodeIndex]);
    }

//---------------------------------------查询方法

    // 返回区间[queryL, queryR]的值,这个区间的值是由我们传入的Merger实现类定义的(可能是区间元素和,区间最大值等!)
    public E query(int queryL , int queryR)
    {
        //将参数不合理的情况排除
        if(queryL < 0 || queryL >= data.length ||
                queryR < 0 || queryR >= data.length || queryL > queryR)
            throw new IllegalArgumentException("Index is illegal.");

        return query(0 , 0 , data.length-1 , queryL , queryR);
    }

    /**
     * 在以treeIndex为根的线段树中[l...r]的范围里,搜索区间[queryL...queryR]的值,
     * 即对于线段树数组下标为 treeIndex 的结点,它在data数组的区间是 [l...r],查询在 [l...r] 区间内 区间[queryL...queryR]的值
     *
     * 对于每一个 tree 数组结点,我们都传入它在 data数组中所表示的区间,我们其实可以将他们封装成为一个线段树的结点类,
     * 结点类包含当前线段树结点 在tree数组的下标:treeIndex,以及在data数组 区间的范围: [l,r]!
     * 理解:int treeIndex, int l, int r:这三个参数都在表示线段树结点的信息!!!
     */
    //我们向下将最开始的查询区间拆分,直到某一个小区间满足当前 tree结点的区间,我们直接将tree结点的值返回即可!
    private E query(int treeIndex, int l, int r, int queryL, int queryR)
    {
        //小区间满足当前 tree结点的区间,我们直接将tree结点的值返回即可!
        //当然,tree结点的值由其左右孩子结点区间值融合而来,这里我们在 buildSegmentTree() 方法中已经定义好!
        if(l == queryL && r == queryR)
            return tree[treeIndex];

        //如果当前tree结点的区间大于 查询的小区间 (小于的情况上面方法中被我们排除了!),查询tree结点左右孩子结点的区间

        //首先求左右孩子结点在tree数组中的下标
        int leftNodeIndex = leftChild(treeIndex);
        int rightNodeIndex = rightChild(treeIndex);

        //求当前tree结点分割为左右孩子区间的时候,区间中间的分割值:左区间:[l,mid],右区间:[mid+1 , r]
        int mid = l + (r - l) / 2;

        if(queryL>=mid+1) //当查询的区间全部在右孩子结点的区间
            return query(rightNodeIndex , mid+1 , r , queryL , queryR);
        else if(queryR<=mid) //当查询的区间全部在左孩子结点的区间
            return query(leftNodeIndex , l , mid , queryL , queryR);

        //否则,我们得将查询区间分为2部分,分别在左右孩子区间查找
        E leftResult = query(leftNodeIndex , l , mid , queryL , mid);
        E rightResult = query(rightNodeIndex , mid+1 , r , mid+1 , queryR);

        //按用户定义的融合方法融合2个值并返回!
        return merger.merge(leftResult,rightResult);
    }

//--------------------------------更新

    /**
     * 将index位置的值,更新为e。
     * 1)首先,我们将存储数据的 data 数组 index位置的值更新为val;
     * 2)其次,表示线段树结点的 tree 数组的所有叶子结点,存储的值也是data数组的值,我们需要找到 tree数组 表示的线段树相应的叶子结点,
     *  将这个叶子结点的值进行更新。
     *  同时,由于这个叶子结点的值更新,而它的父亲结点的区间包含这个叶子结点的值,它的父亲结点的值由它的区间值计算得来,
     *  它的父亲结点的区间也由左右孩子区间融合得来,那么也可以融合左右孩子的值得到父亲结点的值
     */
    public void set(int index, E e)
    {
        if (index < 0 || index >= data.length)
            throw new IllegalArgumentException("Index is illegal");
        //更新data数组
        data[index] = e;
        //调用新的set方法,更新tree数组
        set(0 , 0 , data.length-1 , index , e);
    }

    // 在以treeIndex为根的线段树中更新index的值为e
    private void set(int treeIndex, int l, int r, int index, E e)
    {
        //index必然在data数组的范围之内(前面判断),这里不需要判断index是否合法

        if(l == r)
        {//当l=r,说明找到线段树的叶子结点,该结点的区间就是[index,index],该结点值就是data[index],更新这个结点的值
            tree[treeIndex] = e;
            return;//注意,这里找到叶子结点后,必须终止函数,否则继续向下计算会导致数组角标越界
        }

        //没找到tree数组中对应data数组的区间是 [index,index] 的结点,则继续向下缩小区间寻找
        int leftNodeIndex = leftChild(treeIndex);
        int rightNodeIndex = rightChild(treeIndex);

        int mid = l + (r - l) / 2;

        if(index<=mid)
            set(leftNodeIndex , l , mid , index ,e);//向左子树寻找
        else
            set(rightNodeIndex , mid+1 , r , index ,e);//向右子树寻找

        //由于线段树叶子结点的值得更新,必然会影响父亲结点的区间内的值,从而影响夫妻群结点的值
        //那么我们在更新完当前结点的左孩子区间或者是右孩子区间后(必然有一个区间受影响),将左右孩子的值重新融合,更新当前tree结点的值
        tree[treeIndex] = merger.merge(tree[leftNodeIndex] , tree[rightNodeIndex]);//从线段树的下面往上思考这个更新过程
    }


//----------------------toString方法
    @Override
    public String toString(){
        StringBuilder res = new StringBuilder();
        res.append('[');
        for(int i = 0 ; i < tree.length ; i ++){
            if(tree[i] != null)
                res.append(tree[i]);
            else
                res.append("null");

            if(i != tree.length - 1)
                res.append(", ");
        }
        res.append(']');
        return res.toString();
    }
}

    private SegmentTree<Integer> segmentTree;

    /** O(n):只初始化1次 */
    public NumArray(int[] nums)
    {
        //首先判断传入的数组合法
        if(nums != null && nums.length>0)
        {
            Integer[] data = new Integer[nums.length];
            for (int i = 0; i < nums.length ; i++)
            {
                data[i] = nums[i];
            }
            /**  使用data数组构造线段树,线段树每个结点都代表data数组的一个区间(其实线段树结点存储的是区间按一定规则融合的值)。同时,定义线段树区间融合的方式  */
            segmentTree = new SegmentTree<Integer>(data , (o1 , o2) -> o1+o2 );//同样,定义线段树数组tree结点的值是 data区间结点值和
        }
    }

    public void update(int i, int val) 
    {
        //同样判断线段树是否初始化成功
        if(segmentTree == null)
            throw new IllegalArgumentException("Segment Tree is null");
        segmentTree.set(i , val);
    }
    
    /** O(logn) */
    public int sumRange(int i, int j)
    {
        //有可能nums数组不合法,初始化的时候没有初始化 SegmentTree,判断一下 SegmentTree是否合法
        if(segmentTree == null)
            throw new IllegalArgumentException("Segment Tree is null");

        return segmentTree.query(i , j);//返回查询结果即可!结果就是区间和。每一次查询都是 O(n) 复杂度
    }
}

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

水之积也不厚,则其负大舟也无力

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值