lkb 的小屋

start again. //小蒟蒻也有大梦想

[SMOJ1868]bst计数

这题一看就是个数据结构题,不过具体做法还是很巧妙的。

50%:
简直是送分数据,伪代码都给出了,直接模拟构一棵 BST 就好了。
时间复杂度上限:众所周知 BST 的平衡性十分不可靠,退化成链的时候当然是 O(n2)
ps. 不是很懂比赛的时候为什么会有 40 分的,大概是数组没开够?
pps. 打对暴力还是很重要的,一方面保证即使想不到正解也不至于“未提交”,另一方面可以拿来对拍
代码:

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

using namespace std;

const int MAXN = 3e5 + 100;

struct Tnode {
    Tnode *child[2];
    int val;

    Tnode (int v = 0) : val(v) { child[0] = child[1] = NULL; }
} nodes[MAXN], *current;


int C;
struct BST {
    Tnode *root;
    BST () : root(NULL) {}

    Tnode *get_point(int v) { //当然是选择内存池机制
        (*current) = Tnode(v);
        return current ++;
    }

    void insert_val(Tnode *&cur, int num) { //题目的伪代码硕长,其实三行就能搞定的事情
        if (!cur) { cur = get_point(num); return ; }
        ++ C;
        insert_val(cur -> child[num > cur -> val] , num);
    }
} lkb_bst;

int N;

int main(void) {
    freopen("1868.in", "r", stdin);
    freopen("1868.ans", "w", stdout);
    scanf("%d", &N); current = nodes;
    for (int i = 0; i < N; i++) {
        int X; scanf("%d", &X);
        lkb_bst.insert_val(lkb_bst.root, X);
        printf("%d\n", C);
    }
    return 0;
}


70%:
在讲正解之前,先简单记录一下我比赛时的 yy 想法。(虽然似乎没有什么价值)
在推了一些例子之后,我欣喜地发现,其实向 BST 插入结点的过程,可以拆分成一条条链,且这些链上的结点值都具有单调性。for example:

这样,每次对于一个结点,可以在对应链上找到第一个小于 / 大于待插入值的结点,然后再进入下一条链。找的过程显然可以二分搞一下。
不得不承认我这种水法的初衷只是为了退化的情况不要被卡得太严重,结果多拿了 20 分,已经很满足了。(后面常数硕大,T 掉了)

虽然很菜,也贴一下代码吧。

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
//#include <ctime>
#include <iostream>

using namespace std;

const int MAXN = 3e5 + 100;
const int MAXD = 20;
const int MOD = 1e5 + 7;
int _pow2[MAXN]; //2^i % MOD

struct Tnode {
    Tnode *child[2][MAXD], *par[MAXD]; //child[0][i] 为第 2^i 代左儿子,1 为右;par[i] 为第 2^i 代父亲
    int val, dep, maxd[2]; int id; //maxd[0] 为使 child[0][i] 不为空的最大的 i,1 同理

    Tnode (int v = 0, int d = 0, int x = 0) : val(v), dep(d), id(x) {
        maxd[0] = maxd[1] = -1; //构造函数会导致常数硕大无比,但是不初始化又会爆炸
        for (int i = 0; i < MAXD; i++) child[0][i] = child[1][i] = par[i] = NULL;
    }
} nodes[MAXN], *current;

long long C;
struct BST {
    Tnode *root;
    BST () : root(NULL) {}

    Tnode *get_point(int v, int d, int x) {
        (*current) = Tnode(v, d, x);
        return current ++;
    }

    Tnode *query(Tnode *r, int d) { //询问 r 的第 d 代父亲
        for (int p = 0; d; p++, d >>= 1) if (d & 1) r = r -> par[p];
        return r;
    }

    Tnode *find_node(Tnode *l, Tnode*r, int v, int d) {//在 (l, r] 链上找深度最小的大于或小于 v 的结点,d=0 时在左链中找,1 为右链
        while (l -> dep + 1 < r -> dep) {
            Tnode *mid = query(r, r -> dep - l -> dep >> 1);
            if ((!d && mid -> val < v) || (d && mid -> val > v)) r = mid; else l = mid;
        }
        return r;
    }

    Tnode *calc_far(int d, Tnode *r) { //求 r 在 d 方向上的链的最远儿子(d含义同上)
        for (; r -> maxd[d] != -1; r = r -> child[d][r -> maxd[d]])
            ;
        return r;
    }

    bool is_chain(int d, Tnode *u, Tnode *v) { //判断 u、v 是否在同一条链上(用到了结点编码,常数硕大)
        int de = v -> dep - u -> dep;
        return !d && u -> id * _pow2[de] % MOD == v -> id || d && ((u -> id + 1) * _pow2[de] - 1) % MOD == v -> id;
    }

    long long insert_val(Tnode *&cur, int num) { //从 cur 结点开始插入 num 时 C 的增量
        int t = num > cur -> val;
        Tnode *p = find_node(cur, calc_far(t, cur), num, t);
//      printf("%d %d\n", num, p -> val);
        if (p != cur) return p -> dep - cur -> dep + insert_val(p, num); //可以进入下一条链,对于此链上的中间部分直接加答案
        else { //就是自己的儿子
            Tnode *ch = cur -> child[t][0];
            ch = get_point(num, cur -> dep + 1, ((cur -> id << 1) + t) % MOD);
            p = ch -> par[0] = cur;
            for (int i = 0; p -> par[i]; p = p -> par[i++]) //倍增
                ch -> par[i + 1] = p -> par[i];
            for (int i = 0; ch -> par[i]; i++) {
                if (!is_chain(t, ch -> par[i], ch)) break; //不断向上更新父亲,直到不在同一条链上
                if (!t && num < ch -> par[i] -> val) { ch -> par[i] -> child[0][i] = ch; ch -> par[i] -> maxd[0] = i; }
                else if (t && num > ch -> par[i] -> val) { ch -> par[i] -> child[1][i] = ch; ch -> par[i] -> maxd[1] = i; }
            } //需要分类讨论
            return 1;
        }
    }

    void debug_output(Tnode *root) {
        putchar('(');
        if (root -> child[0][0]) debug_output(root -> child[0][0]);
        putchar(')');
        printf("%d ", root -> val);
        putchar('(');
        if (root -> child[1][0]) debug_output(root -> child[1][0]);
        putchar(')');
    }

    void debug_output2() {
        for (Tnode *i = nodes; i != current; i++) {
            printf("v = %d, par[] = {", i -> val);
            for (int j = 0; i -> par[j]; j++) printf("%d, ", i -> par[j] -> val);
            printf("}, child[0] = {");
            for (int j = 0; i -> child[0][j]; j++) printf("%d, ", i -> child[0][j] -> val);
            printf("}, child[1] = {");
            for (int j = 0; i -> child[1][j]; j++) printf("%d, ", i -> child[1][j] -> val);
            puts("}");
        }
    }
} lkb_bst;

int N;

inline int readint() {
    char c = getchar();
    while (c < '0' || c > '9') c = getchar();
    int num = 0;
    while (c >= '0' && c <= '9') {
        num = (num << 3) + (num << 1) + (c - '0');
        c = getchar();
    }
    return num;
}

inline void writelonglong(long long x) {
    static int t[12];
    int i = 0;
    while (x) {
        t[i++] = x % 10;
        x /= 10;
    }
    while (i--) putchar(t[i] + '0');
    putchar('\n');
}

int main(void) {
//  int start = clock();
    freopen("1868.in", "r", stdin);
    freopen("1868.out", "w", stdout);
    N = readint(); current = nodes; memset(nodes, 0, sizeof nodes);
    _pow2[0] = 1; for (int i = 1; i <= N; i++) _pow2[i] = (_pow2[i - 1] << 1) % MOD;
    int X; X = readint(); puts("0");
    lkb_bst.root = lkb_bst.get_point(X, 0, 1);
    for (int i = 1; i < N; i++) {
        X = readint();
        writelonglong(C += lkb_bst.insert_val(lkb_bst.root, X));
//      lkb_bst.debug_output(lkb_bst.root); putchar('\n');
//      lkb_bst.debug_output2();
    }
//  printf("time used:%d ms", clock() - start);
    return 0;
}


100%:
现在来讲正解。
根据提供的伪代码可以看出来,对于每一个被插入的新结点 x,它对 C 的增加产生的所贡献,其实就是从根结点到 x 将来的父亲路上经过的节点个数。而找父亲的过程中,每次必然向左或向右到达下一层,深度加 1。也就意味着 C 增加的其实就是 x 最终的深度。
则问题转化为:对于一棵给定的 BST,如何快速求出 x 在 BST 中的最终深度?

直接考虑求 x 的深度显然是不好求的。不妨想想,平时我们如何处理一棵给定的树中各结点的深度?
显然,直接从根结点向下跑 DFS(同时标记其深度为 0),每个结点的深度就是其父亲的深度 + 1。
这里其实同理,因为在每个结点被插入时,我们需要求出其深度。也就意味着求得后,可以在插入时将其记录下来,作为结点的附加信息。
这样,在某个结点被插入前,BST 中各结点都已经知道其深度了。只要找到 x 的父亲,将其深度加上 1,就是 x 的深度。
则问题转化为:对于一棵给定的 BST,如何快速求出 x 的父亲?

有如下定理:在一棵无重复结点的 BST 中(以下分析均以此为前提),x 的前驱 u 和后继 v 中深度较大的就是 x 的父亲,且 uv 一定存在祖先关系。
为什么一定是前驱或后继?可以这样理解:BST 的插入,其实就是从根结点出发,一步步“逼近”x 的过程(手动模拟有助理解),而最接近的无非前驱后继。

定理的证明如下:

  • 首先证明 uv 之间一定存在祖先关系,可以用反证法。假如 uv 之间并没有祖先关系,那么它们必然存在一个 LCA 结点 y,又因为没有重复结点,所以可以分为两种情况讨论:

    • u<x<y<v,则这个 y 比所求得的 v 更接近 x 的值,这与后继的定义“大于 x 的结点中最小的”相矛盾;
    • 类似地,u<y<x<v,则这个 y 比所求得的 u 更接近 x 的值,这与前驱的定义“小于 x 的结点中最大的”相矛盾。

    联立可知,uv 之间必然存在祖先关系。

  • 如何说明“uv 中深度较大的就是 x 的父亲”?不妨记结点 i 的深度为 dep(i),同样可以进行分类讨论:

    • dep(u)<dep(v) 时,显然 vu 的右子树中,则 u 的右儿子非空。根据前驱的定义可知 u<x,因此 x 会被插入到 u 的右子树中,但 u 的右儿子非空,所以 x 只能成为 v 的左儿子。
    • dep(v)<dep(u) 时,显然 uv 的左子树中,则 v 的左儿子非空。根据后继的定义可知 x<v,因此 x 会被插入到 v 的左子树中,但 v 的左儿子非空,所以 x 只能成为 u 的右儿子。

可能会想到一种情况:会不会对于 uv 当中深度较大的结点,x 将要插入的位置非空,导致没有位置被插入?这种情况当然是不成立的。

证明:

  • 不妨假设 dep(u)<dep(v)v 的左子树不为空。根据 BST 性质可知,对于 v 的左子树中任意结点 k 都有 k<v
  • 同时,ku 的右子树中,则 u<k。整理得 u<k<v
  • 此时还是可以延续上面的思路,分类讨论
    • x<k,则 u<x<k<v,与后继的定义相矛盾
    • k<x,则 u<k<x<v,与前驱的定义相矛盾。
  • 因此,当 dep(u)<dep(v) 时,v 的左子树必定为空。同理可以证明 dep(v)<dep(u) 时,u 的右子树必然为空。换言之,x 被插入后必然是叶子结点。

这样,则问题转化为如何在 BST 中查询到 uv,从而得到它们的深度,比较后将 x 插入。传统的 BST 不平衡,肯定不能直接查,那么……
想到了什么?没错,就是平衡树。可以用 Treap 或者 Splay 之类的(我不会)平衡树维护这棵 BST,旋转之后形态的改变并不重要,因为不影响对前驱和后继的查询,而我们所关注的只是它们的深度,是结点的附加信息,旋转没有任何关系。

到这里,做法应该已经呼之欲出了:用 Treap 保存输入的元素,给每个结点加上附加值深度。求出 uv 之后,两者深度的最大值就是 C 的增量,然后将新输入的元素插入到 Treap 里面即可。
时间复杂度?当然是 O(nlog2n) 的了。空间复杂度则是 O(n)

总结一下,这题里面多次对问题进行了转化,在证明的过程中用到了分类讨论的思想。做出来之后再来看,这题目还是相当有意思的,不失为一道好题。
不过也应该反思,为什么自己在分析问题的过程中没有意识到求前驱和后继?
问题求解的过程应该是步步递进的,主要原因还是没有抓住 BST 特有的性质,也就是对基础知识掌握得不够透彻,认识得不够深入。
分析和转化问题的能力,还是要在大量的练习和总结反思当中才能得到提高。一旦懈怠,大脑很快就会退化。

参考代码:

//1868.cpp
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>

using namespace std;

const int MAXN = 3e5 + 100;

struct Tnode {
    Tnode *child[2];
    int dep, fix, val; //顺便维护下深度
    Tnode (int d = 0, int v = 0) : dep(d), fix(rand()), val(v) {
        child[0] = child[1] = (Tnode*)0;
    }
} nodes[MAXN], *current;

struct Treap {
    Tnode *root;
    Treap () : root(NULL) {}

    Tnode *get_point(int d, int v) {
        (*current) = Tnode(d, v);
        return current ++;
    }

    void rotate(Tnode *&cur, int t) {
        Tnode *ch = cur -> child[t ^ 1];
        cur -> child[t ^ 1] = ch -> child[t];
        ch -> child[t] = cur;
        cur = ch;
    }

    void insert_val(Tnode *&cur, int value, int depth) {
//      printf("insert_val %d\n", value);
        if (!cur) cur = get_point(depth, value);
        else {
            int t = value > cur -> val;
            insert_val(cur -> child[t], value, depth);
            if (cur -> child[t] -> fix < cur -> fix) rotate(cur, t ^ 1);
        }
    }

    Tnode *query_pred(Tnode *cur, int value, Tnode *best) { //前驱
//      printf("query_pred %d\n", value);
        if (!cur) return best;
        else if (value < cur -> val) return query_pred(cur -> child[0], value, best);
        else return query_pred(cur -> child[1], value, cur);
    }

    Tnode *query_succ(Tnode *cur, int value, Tnode *best) { //后继
//      printf("query_succ %d\n", value);
        if (!cur) return best;
        else if (value > cur -> val) return query_succ(cur -> child[1], value, best);
        else return query_succ(cur -> child[0], value, cur);
    }
} lkb_treap;

int N;

int main(void) {
    freopen("1868.in", "r", stdin);
    freopen("1868.out", "w", stdout);
    scanf("%d", &N); long long C = 0; current = nodes;
    for (int i = 0; i < N; i++) {
//      printf("i=%d\n", i);
        int X; scanf("%d", &X);
        Tnode *pred = lkb_treap.query_pred(lkb_treap.root, X, (Tnode*)0) ;
        Tnode *succ = lkb_treap.query_succ(lkb_treap.root, X, (Tnode*)0);
        int d = max(pred ? pred -> dep : 0, succ ? succ -> dep : 0); //取深度更大的,但要注意可能不存在前驱或后继,所以要特判一下
        printf("%lld\n", C += d);
        lkb_treap.insert_val(lkb_treap.root, X, d + 1);
    }
    return 0;
}


阅读更多
版权声明:本文为博主原创文章,未经博主允许不得转载。 https://blog.csdn.net/u013686535/article/details/77248761
个人分类: 解题报告 SMOJ Treap
想对作者说点什么? 我来说一句

没有更多推荐了,返回首页

不良信息举报

[SMOJ1868]bst计数

最多只允许输入30个字

加入CSDN,享受更精准的内容推荐,与500万程序员共同成长!
关闭
关闭