点分治专题

简介

如果处理“所有经过某一个顶点的链对答案的贡献”的时间复杂度为 O ( n ) O(n) O(n)或者 O ( n l o g n ) O(nlogn) O(nlogn),那么运用点分治的思想可以把问题规模降为 O ( n l o g n ) O(nlogn) O(nlogn) O ( n l o g 2 n ) O(nlog^2n) O(nlog2n),而非暴力枚举顶点计算答案的 O ( n 2 ) O(n^2) O(n2)

所以说,点分治是一种在树上统计合法链个数的思想。显而易见的,对于当前顶点 x x x,任意一条链要么经过 x x x,要么不经过 x x x。于是我们只需要计算那些经过 x x x的链,剩下不经过 x x x的链则统统放在子树中递归计算,此为分治。

那么,怎么分治?按照常规的递推思路来考虑,处理完当前顶点后,递归进入所有与顶点相邻的顶点并分别处理它们?不行,因为很容易可以找到一条退化成链的树(见下图),每一层递归只能使问题规模减少1(去掉一个顶点),却要递归到第n层,显然,时间复杂度是 O ( n 2 ) O(n^2) O(n2)的。
在这里插入图片描述

比起点分治,这似乎更像树形dp,但树形dp的前提是每一个点的状态都可以借助其子结点 O ( 1 ) O(1) O(1)处理(在例题中可以看到一个既能用点分治,又能用树形dp来解决的例子)。但如果每一条链的状态都不能够合并,无法用dp来降低复杂度呢?实际上,点分治可以说就是为了解决一些无法保存状态的树形dp而发明的一种“优雅的暴力”。

点分治维护其复杂度的关键在于“找重心”。依然从最难处理的链来考虑问题,如果每次递归进入一颗子树时,先找子树的重心,再以重心来分割子树,可以发现每一次分割都使问题的规模减少为原来的一半,那么经过 O ( l o g n ) O(logn) O(logn)次分割,问题规模一定被减少到1,也就是每个顶点都已被考虑到。如果用递归树来刻画这一过程(见下图),可以发现树高为 l o g ( n ) log(n) log(n),每一层的时间之和都为 T ( n ) T(n) T(n),于是,我们成功将 O ( n 2 ) O(n^2) O(n2)的复杂度,通过找重心降低到 O ( n l o g n ) O(nlogn) O(nlogn)
在这里插入图片描述

分治

点分治的题目中,“分治”部分几乎是一成不变的,无非是找到重心后,以重心为根重新计算子树大小,并利用计算结果进一步去找子树的重心。简而言之就是两个函数Getroot和Getsz而已。

int Getroot(int x, int f) {
    int sum = 0, mx = 0, tmp;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        tmp = Getroot(to, x);
        sum += tmp;
        mx = max(mx, tmp);
    }
    sum++;
    mx = max(mx, tot - sum);
    if (mx < MN) {
        MN = mx;
        rt = x;
    }
    return sum;
}

void Getsz(int x, int f) {
    sz[x] = 0;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        Getsz(to, x);
        sz[x] += sz[to];
    }
    sz[x]++;
}

计算贡献

但这只是第一步。点分治中最重要,最灵活,也最需要投入思考的点,其实是一开始提到的如何计算“所有经过某一个顶点的链对答案的贡献”。

P3806 【模板】点分治1

以这道最基本的模板题来说。相对而言比较经典的计算过程是这样的。

1、一次性收集所有链信息,存放在一个vector或数组中。

void Getdis(int x, int len, int f) {
    D.push_back(len);
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to] || to == f) continue;
        Getdis(to, len + v, x);
    }
}

2、利用数据结构(这里用的是桶)计算链对答案的贡献

void calc(int x, int len, int type) {
    D.clear();
    Getdis(x, len, -1);
    for (int y : D) {
        if (y > M) continue;
        for (int i = 1; i <= m; i++) {
            if (q[i] - y < 0) continue;
            cnt[i] += type * bucket[q[i] - y];
        }
        bucket[y]++;
    }
    for (int y : D) {
        if (y > M) continue;
        bucket[y]--;
    }
}

3、先计算子树内所有链的贡献(cal(x, 0, 1)),再利用容斥原理,去除不合法的链(calc(to, v, -1))

void solve(int x) {
    vis[x] = 1;
    calc(x, 0, 1);
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to]) continue;
        calc(to, v, -1);
        tot = sz[to], MN = INF;
        Getroot(to, -1);
        Getsz(rt, -1);
        solve(rt);
    }
}

关于第三点或许需要再举个例子说明一下。
在这里插入图片描述

若当前以A为根收集链信息,一共可以得到五条链:

A

A->B

A->B->D

A->B->E

A->C

而用于计算答案的链,实际上是从收集到的链中任选两条,组合成一条新链再进行计算的。比如选择A->B和A->C,实际上是B->A->C这条链,选择A和A->B实际上就是A->B这条链,这两条都是合法的。但如果选择的是A->B->D和A->B->E时,选择的其实是D->B->A->B->E,但D到E的简单路径是D->B->E,也就是这条链不存在,这是不合法的。实际上,同时经过A和B的两条链组合之后总是不合法的,而cal(to, v, -1)的意义就在于选出这样的不合法的链,并把它们对答案的贡献消除。

完整的代码贴在这里

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " : " << x << endl
using namespace std;
typedef long long LL;

const int N = 1e4 + 5, M = 1e7 + 5, P = 1e2 + 5, INF = 0x3f3f3f3f;

struct Edge {
    int to, v;
};

int n, m, k, q[P], cnt[P], bucket[M];

vector<int> D;

/**variables of tree divide*/
int tot, sz[N], MN, rt;

bool vis[N];

vector<Edge> G[N];

int Getroot(int x, int f) {
    int sum = 0, mx = 0, tmp;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        tmp = Getroot(to, x);
        sum += tmp;
        mx = max(mx, tmp);
    }
    sum++;
    mx = max(mx, tot - sum);
    if (mx < MN) {
        MN = mx;
        rt = x;
    }
    return sum;
}

void Getsz(int x, int f) {
    sz[x] = 0;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        Getsz(to, x);
        sz[x] += sz[to];
    }
    sz[x]++;
}

void Getdis(int x, int len, int f) {
    D.push_back(len);
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to] || to == f) continue;
        Getdis(to, len + v, x);
    }
}

void calc(int x, int len, int type) {
    D.clear();
    Getdis(x, len, -1);
    for (int y : D) {
        if (y > M) continue;
        for (int i = 1; i <= m; i++) {
            if (q[i] - y < 0) continue;
            cnt[i] += type * bucket[q[i] - y];
        }
        bucket[y]++;
    }
    for (int y : D) {
        if (y > M) continue;
        bucket[y]--;
    }
}

void solve(int x) {
    vis[x] = 1;
    calc(x, 0, 1);
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to]) continue;
        calc(to, v, -1);
        tot = sz[to], MN = INF;
        Getroot(to, -1);
        Getsz(rt, -1);
        solve(rt);
    }
}

int main() {
    cin >> n >> m;
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        scanf("%d %d %d", &u, &v, &w);
        G[u].push_back({v, w});
        G[v].push_back({u, w});
    }
    for (int i = 1; i <= m; i++) scanf("%d", &q[i]);
    tot = n, MN = INF;
    Getroot(1, -1);
    Getsz(rt, -1);
    solve(rt);
    for (int i = 1; i <= m; i++) {
        printf("%s\n", cnt[i] ? "AYE" : "NAY");
    }
    return 0;
}

至此,一个完整的点分治就结束了。

但这并不是唯一的计算贡献的方法。同样以这道题目来说,我们其实不需要一次性收集所有的链,而是可以先收集某个子树中的链,根据桶的信息修改答案,把收集到的链存入桶中,再去下一个子树收集链信息。如此一来就可以避免计算不合法的链,也就不需要利用容斥来消除它们。

void calc(int x) {
    judge[0] = 1;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to]) continue;
        D.clear();
        Getdis(to, T.v, x);
        for (int i : D) {
            for (int j = 1; j <= m; j++)
                if (query[j] >= i) ans[j] |= judge[query[j] - i];
        }
        for (int i : D) {
            if (i < 10000010 && !judge[i]) {
                C.push_back(i);
                judge[i] = 1;
            }
        }
    }
    for (int i : C) judge[i] = 0;
    C.clear();
}

我个人认为点分治只能说是一种思想,而不能称为一种算法,就在于每道题计算贡献的方法都不一样。因此,考虑是否要运用点分治解决问题,最关键的就是能否找到一个计算贡献的方法。许多计算方法都要套一个树状数组或者线段树之类的结构,这也都会对最终的时间复杂度产生影响。

例题

P2634 [国家集训队]聪聪可可

最简单的点分治, O ( n ) O(n) O(n)收集信息, O ( 1 ) O(1) O(1)修改答案即可。

#include <bits/stdc++.h>
#define debug(x) cerr << #x << " : " << x << endl
using namespace std;
typedef long long LL;

const int N = 2e5 + 5, INF = 0x3f3f3f3f;

struct Edge {
    int to, v;
};

vector<Edge> G[N];

bool vis[N];

int n, tot, sz[N], MN, rt, a[N];

LL ans[3], cnt[3];

int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;}

int Getroot(int x, int f) {
    int sum = 0, mx = 0, tmp;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        tmp = Getroot(to, x);
        sum += tmp;
        mx = max(mx, tmp);
    }
    sum++;
    mx = max(mx, tot - sum);
    if (mx < MN) {
        MN = mx;
        rt = x;
    }
    return sum;
}

void Getsz(int x, int f) {
    sz[x] = 0;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        Getsz(to, x);
        sz[x] += sz[to];
    }
    sz[x]++;
}

void Getdis(int x, int len, int f) {
    cnt[len]++;
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to] || to == f) continue;
        Getdis(to, add(len, v), x);
    }
}

void calc(int x, int len, int type) {
    memset(cnt, 0, sizeof(cnt));
    Getdis(x, len, -1);
    for (int i = 0; i <= 2; i++) {
        for (int j = 0; j <=2; j++) {
            ans[add(i, j)] += cnt[i] * cnt[j] * type;
        }
    }
}

void solve(int x) {
    vis[x] = 1;
    calc(x, 0, 1);
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to]) continue;
        calc(to, v, -1);
        tot = sz[to], MN = INF;
        Getroot(to, -1);
        Getsz(rt, -1);
        solve(rt);
    }
}

int main() {
    cin >> n;
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        scanf("%d %d %d", &u, &v, &w);
        w %= 3;
        G[u].push_back({v, w});
        G[v].push_back({u, w});
    }
    tot = n, MN = INF;
    Getroot(1, -1);
    Getsz(rt, -1);
    solve(rt);
    LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]);
    printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd);
    return 0;
}

可能因为要处理的信息太简单了,用树形dp一样可以过这道题,复杂度还更低。

#include <bits/stdc++.h>

using namespace std;
typedef long long LL;

const int N = 2e4 + 5;

int n;

LL dp[N][3], sum[N][3], ans[3];

struct Edge {
    int to, v;
};
vector<Edge> G[N];

int add(int x, int y) {return x + y >= 3 ? x + y - 3 : x + y;}

int sub(int x, int y) {return x - y < 0 ? x - y + 3 : x - y;}

void dfs(int x, int f) {
    dp[x][0]++;
    for (auto T : G[x]) {
        int to = T.to, v = T.v;
        if (to == f) continue;
        dfs(to, x);
        for (int i = 0; i <= 2; i++) {
            for (int j = 0; j <= 2; j++) {
                ans[add(i, j)] += dp[to][sub(i, v)] * dp[x][j];
            }
        }
        for (int i = 0; i <= 2; i++) {
            dp[x][i] += dp[to][sub(i, v)];
        }
    }
}

int main() {
    cin >> n;
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        scanf("%d %d %d", &u, &v, &w);
        w %= 3;
        G[u].push_back({v,w});
        G[v].push_back({u,w});
    }
    dfs(1, -1);
    ans[0] *= 2;
    ans[1] *= 2;
    ans[2] *= 2;
    ans[0] += n;
    LL gcd = __gcd(ans[0], ans[0] + ans[1] + ans[2]);
    printf("%lld/%lld", ans[0] / gcd, (ans[0] + ans[1] + ans[2]) / gcd);
    return 0;
}

P4178 Tree

利用树状数组统计答案,非常经典。

#include <bits/stdc++.h>

using namespace std;

const int N = 4e4 + 5, M = 2e4 + 5, INF = 0x3f3f3f3f;

struct Edge {
    int to, v;
};

struct BIT {
    int t[N];
    int lowbit(int x) {return x &(-x);}
    void add(int x, int v) {
        x += 1;
        while(x < N) {
            t[x] += v;
            x += lowbit(x);
        }
    }
    int ask(int x) {
        x += 1;
        int res = 0;
        while(x) {
            res += t[x];
            x -= lowbit(x);
        }
        return res;
    }
}bit;

vector<Edge> G[N];

vector<int> D;

bool vis[N];

int n, k, tot, sz[N], MN, rt, ans;

int Getroot(int x, int f) {
    int sum = 0, mx = 0, tmp;
    for (auto T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        tmp = Getroot(to, x);
        sum += tmp;
        mx = max(mx, tmp);
    }
    sum++;
    mx = max(mx, tot - sum);
    if (mx < MN) {
        MN = mx;
        rt = x;
    }
    return sum;
}

void Getsz(int x, int f) {
    sz[x] = 0;
    for (Edge T : G[x]) {
        int to = T.to;
        if (vis[to] || to == f) continue;
        Getsz(to, x);
        sz[x] += sz[to];
    }
    sz[x]++;
}

void Getdis(int x,int len, int f) {
    D.push_back(len);
    for (Edge T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to] || to == f) continue;
        Getdis(to, len + v, x);
    }
}

int calc(int x, int len) {
    int res = 0;
    Getdis(x, len, -1);
    D.push_back(0);
    for (int y : D) {
        if (y > k) continue;
        res += bit.ask(k - y);
        bit.add(y, 1);
    }
    for (int y : D) {
        if (y > k) continue;
        bit.add(y, -1);
    }
    D.clear();
    return res;
}

void solve(int x) {
    vis[x] = 1;
    ans += calc(x, 0);
    for (Edge T : G[x]) {
        int to = T.to, v = T.v;
        if (vis[to]) continue;
        ans -= calc(to, v);
        tot = sz[to], MN = INF;
        Getroot(to, -1);
        Getsz(rt, -1);
        solve(rt);
    }
}

int main() {
    cin >> n;
    for (int i = 1, u, v, w; i <= n - 1; i++) {
        scanf("%d %d %d", &u, &v, &w);
        G[u].push_back({v, w});
        G[v].push_back({u, w});
    }
    cin >> k;
    tot = n, MN = INF;
    Getroot(1, -1);
    Getsz(rt, -1);
    solve(rt);
    printf("%d\n", ans - n);
    return 0;
}

Constructing Ranches

需要利用排序、离散化、树状数组来计算贡献,可能还要卡卡常,比较清奇。

#include <bits/stdc++.h>

using namespace std;
typedef long long LL;

const int N = 2e5 + 5, INF = 0x3f3f3f3f;

struct BIT {
    int t[N];
    int lowbit(int x) {return x &(-x);}
    void add(int x, int v) {
        x++;
        while(x < N) {
            t[x] += v;
            x += lowbit(x);
        }
    }
    int ask(int x) {
        x++;
        int res = 0;
        while(x) {
            res += t[x];
            x -= lowbit(x);
        }
        return res;
    }
}bit;

struct Edge{
    LL len;
    int mx;
};
vector<int> G[N], dc, add;

vector<Edge> D;

unordered_map<LL, int> pos;

bool vis[N];

int T, n, tot, sz[N], MN, rt, a[N];

LL ans;

int Getroot(int x, int f) {
    int sum = 0, mx = 0, tmp;
    for (int to : G[x]) {
        if (vis[to] || to == f) continue;
        tmp = Getroot(to, x);
        sum += tmp;
        mx = max(mx, tmp);
    }
    sum++;
    mx = max(mx, tot - sum);
    if (mx < MN) {
        MN = mx;
        rt = x;
    }
    return sum;
}

void Getsz(int x, int f) {
    sz[x] = 0;
    for (int to : G[x]) {
        if (vis[to] || to == f) continue;
        Getsz(to, x);
        sz[x] += sz[to];
    }
    sz[x]++;
}

void Getdis(int x,LL len, int mx, int f) {
    D.push_back({len, mx});
    for (int to : G[x]) {
        if (vis[to] || to == f) continue;
        Getdis(to, len + a[to], max(mx, a[to]), x);
    }
}

LL calc(int x, LL len, int mx, int top) {
    dc.clear(); pos.clear(); D.clear(); add.clear();
    if (top == 0) top = a[x];
    LL res = 0;
    Getdis(x, len, mx, -1);
    sort(D.begin(), D.end(), [&](Edge u, Edge v) {
        return u.mx > v.mx;
    });
    for (auto y : D) {
        dc.push_back(y.len - top);
    }
    sort(dc.begin(), dc.end());
    dc.erase(unique(dc.begin(), dc.end()), dc.end());
    for (auto y : D) {
        int idx = lower_bound(dc.begin(), dc.end(), y.len - top) - dc.begin();
        res += bit.ask(idx);
        idx = lower_bound(dc.begin(), dc.end(), y.mx * 2 - y.len + 1) - dc.begin();
        if (idx >= 0) {
            bit.add(idx, 1);
            add.push_back(idx);
        }
    }
    for (auto y : add) bit.add(y, -1);
    return res;
}

void solve(int x) {
    vis[x] = 1;
    ans += calc(x, a[x], a[x], 0);
    for (int to : G[x]) {
        if (vis[to]) continue;
        ans -= calc(to, a[x] + a[to], max(a[x], a[to]), a[x]);
        tot = sz[to], MN = INF;
        Getroot(to, -1);
        Getsz(rt, -1);
        solve(rt);
    }
}

int main() {
    cin >> T;
    while(T--) {
        cin >> n;
        for (int i = 1; i <= n; i++) {
            vis[i] = 0;
            scanf("%d", &a[i]);
            G[i].clear();
        }
        ans = 0;
        for (int i = 1, u, v, w; i <= n - 1; i++) {
            scanf("%d %d", &u, &v);
            G[u].push_back(v);
            G[v].push_back(u);
        }
        tot = n, MN = INF;
        Getroot(1, -1);
        Getsz(rt, -1);
        solve(rt);
        printf("%lld\n", ans);
    }
    return 0;
}
  • 2
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值