浅谈点分治

浅谈点分治

如果整个点分治像是一个泳池的话,这篇文章所呈现的内容大概就是他旁边的水上乐园。并没有任何系统的介绍(会有很好很好的文章链接推荐)这篇文章大概就是在人云亦云的随便写写感觉。

(防止误导)
大佬博客的链接1
大佬博客的链接2

点分治在干什么?

通常有一类问题是在树上的路径统计。这种路径与树的形态无关,(通常不是一个起点终点都固定的链,或者不会限制不能向上走(朝向父亲)或者不能向下走)。
简而言之,点分治处理的问题通常与树本身的形态没有太大关系,因为点分治的处理过程会彻底的破坏原有的树的结构。


让我先插入一刀例题

题目描述

给定一棵有 n n n 个点的树,询问树上距离为 k k k 的点对是否存在。

输入格式

第一行两个数 n , m n,m n,m
2 2 2 到第 n n n 行,每行三个整数 u , v , w u, v, w u,v,w,代表树上存在一条连接 u u u v v v 边权为 w w w 的路径。
接下来 m m m 行,每行一个整数 k k k,代表一次询问。

数据要求

对于 100 % 100\% 100% 的数据,保证 1 ≤ n ≤ 1 0 4 1 \leq n\leq 10^4 1n104 1 ≤ m ≤ 100 1 \leq m\leq 100 1m100 1 ≤ k ≤ 1 0 7 1 \leq k \leq 10^7 1k107 1 ≤ u , v ≤ n 1 \leq u, v \leq n 1u,vn 1 ≤ w ≤ 1 0 4 1 \leq w \leq 10^4 1w104


做法1. 先说一个暴力的做法,先离线所有询问,然后无脑“嗯”求。简单讲就是枚举树上的两个点,然后暴力去求LCA,然后判断某个k值是否出现了。这种算法的复杂度是 O ( n 2 l o g n ) O(n^2logn) O(n2logn)

点分治的核心思想

任何分治思想说起来非常简单,因为如果一个分治思想说出来不简单那就不叫分治了
简单讲我们要统计具有某种性质的路径(这道题目中就是统计距离长度为k的路径数量是否为0),我们如何保证统计过程不重不漏呢?很简单,我们对树进行一次遍历,每次统计经过当前节点的合法路径数量,最后加起来就完事了。(意识到这一点很重要,同任何分治算法一样,我们只需要关注眼前的统计就好了,任何不经过当前节点的路径的统计工作都将由其他节点完成)

做法2. 一个比他聪明很多的方法是“贡献法”,也是“分治思想法”。枚举lca,然后统计以当前节点为LCA的节点对答案的贡献。简而言之就是枚举当前节点的所有子树。对于一个子树中的各个深度的节点首先查询相应的能拼成长度为k的节点是否出现过,当前子树全部遍历完之后,把当前子树的所有节点的深度放到一个桶中,然后继续遍历。(k的上限是可以接受的,所以可以直接开一个(表示不同深度是否出现的)桶存一下,不然需要map之类的多一个log的算法)这么做是为了防止存在子树自己和自己匹配的情况,这一点不允许是因为避免重复,因为这种情况只要遍历下去子树自己会处理(后面核心思想也会说到)。
那么这种算法的复杂度是多少呢?最坏情况下是一条链的时候,时间复杂度会被卡成 O ( n 2 ) O(n^2) O(n2),比暴力算法优秀,但是还是不能接受。

点分治的复杂度是如何保证的

回想一下分治排序的复杂度是如何保证的,这个算法和分治排序有异曲同工的妙处。分治排序每一层的复杂度都是O(n)算法,这并不快。但是分治算法很重要的一点是

不论原来的序列长度如何,他只会分治log(n)层。

分治排序算法每次都会把序列长度砍半,那我们思考一下点分治将如何保证呢?对!没错,就是每次从重心分治。回顾一下重心的定义。

重心的定义:“任何子树的大小都不超过整棵树的1/2”

那么也就是说,每次分后的子树都将至少减少一半,自然整棵树的分治也就不会超过log(n)层。所以不考虑具体处理过程的话,整个点分治算法的时间复杂度是nlog(n)的。

做法3.(正解) 其实就是在上面的做法上每次选择树的重心作为遍历对象,这样的好处是我们保证了最多只会遍历树log(n)次,再加上上面的做法,再加上对问题离线。时间复杂度就保证到了nlog(n)。

一些找重心的下头做法的避雷

模版代码:
好吧我承认我的代码注释写的不算详细,只是为了给大家一个文章参考(都说了是写写所思所想啦,如果想看详细注释,图文解释的参考大佬博客链接吧)

// 模版例题1 : 点分治 + 桶 / 双指针(见双指针部分)
// https://www.luogu.com.cn/problem/P3806
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

const int N = 1e4;
int head[N + 2], nxt[N * 2 + 2], ver[N * 2 + 2], tot, edge[N * 2 + 2];
int q[102], ans[N + 2], v[N + 2], sz[N + 2], mx[N + 2], S, root, d[N + 2];
int n, m;
bool t[10000004];
std::vector<int> rem;

void add(int x, int y, int z) {
	// 	建立图,没啥好说的
    tot++;
    edge[tot] = z;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void find(int x, int fa) {
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa || v[y]) continue;
        find(y, x);
        sz[x] += sz[y];
        mx[x] = std::max(mx[x], sz[y]);
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;  // 找到树的重心
}

void getdis(int x, int fa) {
    rem.push_back(d[x]);  // 记录下本次根节点的当前子节点搜索的所有深度
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (y == fa || v[y]) continue;
        d[y] = d[x] + z;
        getdis(y, x);
    }
}

void calc(int x) {
    // 这里其实不是向下遍历,就只是当前节点的子节点
    std::vector<int> res;  // 记录下当前跟节点下的所有深度信息
    t[0] = 1;
    res.push_back(0);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y]) continue;
        rem.clear();
        d[y] = z;  // 初始化深度
        getdis(y, x);
        for (int j = 0; j < (int)rem.size(); j++) {
            for (int k = 1; k <= m; k++) {
                if (q[k] >= rem[j] && t[q[k] - rem[j]]) ans[k] = 1;
            }
        }
        for (auto i : rem)
            if (i <= 1e7) t[i] = 1, res.push_back(i);
    }
    for (auto i : res) t[i] = 0;  // 其实就是 memset 0,但是这样时间复杂度会爆炸,所以才这样操作。
}

void solve(int x) {
    std::cout << x << std::endl;
    v[x] = 1;  // 保证了solve遍历不会出现循环
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);  // 第一次其实就是求sz
        S = sz[y], root = 0, mx[0] = 1e9;
        find(y, x);
        solve(root);
    }
}

int main() {
    std::cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        std::cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    for (int i = 1; i <= m; i++) {
        std::cin >> q[i];
    }
    solve(1);
    for (int i = 1; i <= m; i++) printf(ans[i] ? "AYE\n" : "NAY\n");
    return 0;
}

例题选讲

例题一:点分治模版

题意:

统计长度为k的路径是否出现

思路:

k本身不大,直接暴力开桶存,离线所有询问后每次暴力查找是否有符合要求的即可

例题二:Race

题意:

给一棵树,每条边有权。求一条简单路径,权值和等于 k,且边的数量最小

思路:

每次维护子树中的两个值,一个是深度,一个是距离。找到所有距离等于k的情况,取深度和为最小的情况即可。同样,k不大,可以直接暴力开桶进行更新。

显然错误的思路:

这题直接用双指针的思路是显然错误的,至少未经过优化的双指针是这样的(也可能您想的思路直接就是优化过的思路,这里单纯为一部分同学避雷,有无病呻吟之嫌疑 ) 。因为处理不了如下情况。

// 假设我们想找距离为k=10的节点对来更新答案
// 距离:
	1 2 3 3 4 5 6 7 7 7 8 9
   	    ^             ^
  	  指针1          指针2

对于当前情况不论怎么移动指针,移动哪个指针都是错误的。因为所有满足要求的3和7组合的情况有6种,稍微模拟一下双指针就能发现不论怎么移动都是显然缺少一些情况的。自然就不能保证找到路径数量最少的情况。故错误。而上一道题由于只是判断是否出现,所以该算法在上一道题又是显然正确的。

一个显然的优化是把所有相同的值只保留一个到根节点路径数最小的值,其他都删掉,这样显然是正确的。

例题三:Tree

题意:

给定一棵 n 个节点的树,每条边有边权,求出树上两点距离小于等于 k 的点对数量。

思路:

k还是不大 ,所以直接开一个桶暴力存就好啦,每次查询的时间复杂度太高了?直接开一手树状数组就好了,太暴力了,不想多讲了。只要注意不要计算某一个子树自己对自己的贡献就好了。

例题四:消息传递

题意:

给一棵树,每次询问到一个目标节点的距离为k的节点数量

思路:

思考一下这道题目和模版题目(第一题)的显然区别是什么?

本质区别是这个问题具体到了某一个固定的点(首先我们离线处理所有问题,把所有问题都挂在点上),而不是对于整棵树的贡献了,也就是说在计算贡献的时候我们不能全都统计到一个总答案里,而是把具体对某个点的答案统计到某个点上。

意识到这一点十分重要,因为这意味着我们不经需要计算子树对当前节点的贡献,还需要把当前节点对子树对贡献挂到子树中具体的某个点上。这就不能简单的遍历,计算贡献,然后更新桶。这样简单的解决了,因为这样没有办法把“后面”的贡献传递到“前面”的子树上。

具体的,这一点实现需要注意不重不漏,通常有两种做法。

做法1: 正反两遍顺序遍历更新答案,因为正序会更新前面子树对当前子树对贡献,倒序会更新后面子树对自己的贡献,两遍正好补充不漏
双序遍历通常只适用于需要向儿子传递贡献的题目,因为不需要传递贡献的题目单循环就好了,完全没必要正反这么搞

做法2:采用容斥原理,通常写一个计算自己对自己的贡献的函数,然后对每个子树掉用一次,对根节点掉用一次,然后简单容易下即可。
当然这种做法有一点细节,在计算根节点自己对自己的贡献的时候,一定不要忘记把自己加入集合再计算(不要只想着把子树都合并然后忘了把自己加进去
这种做法所有情况都适用,不是只适用于需要向儿子传递贡献的做法

例题五:Ruri Loves Maschera

题意:

给一棵树,定义到一个节点不超过R,不小于L条路径数的路径为合法路径。一条路径的价值为路径上所有边权值的最大值。问所有合法路径的价值总和。

思路:

首先这又是一道全体贡献的题目,不涉及到对于某个具体点的求解,也就不用父亲向孩子进行传递了。

这道题目的核心难点在于要维护最大值,对于一条路径的最大值尚好维护,但是难点在于经过根节点操作,“拼接长度”后仍然维护最大值。也就是说对于一条路径,我们完全不知道拼接后最大值是否会被更新,一个显然的思想是,(称当前路径的最大值为原本值)对于其他子树中所有最大值小于原本值的路径贡献都是原本值,对于所有最大值大于原本值的路径贡献都是该路径的最大值。但是这样不太好维护。

这里我们用一个十分巧妙的贡献思想。思考一下一个值什么时候会产生贡献,只有当他是最大值的时候!!!换句话说,我们只要记统计拼接的另一半最大值比当前最大值小的合法(指路径数量)段的数量然后乘上当前最大值就可以了!!!这就是当前段对答案的贡献!!!这就转换成了一个二维数点问题,很经典的对于第一维度排序,对于第二维度用树状数组统计就好了。

这里还有一个小技巧,就是我们选择“最大值”为第一维度,路径数量为第二维度,因为这样可以省去离散化的操作。

例题六:聪聪与可可

题意:

给一棵树,统计有多少对节点,节点间的路径长度是3的倍数

思路:

在统计的过程中只要开一个%3=0,1,2的桶然后去计算贡献就好了,我这道题目用的是容斥法,其实由于不涉及到向儿子传递贡献,只要单循环遍历计算贡献就好了。

题目的代码

所有代码都放在最后了,按需自取。

// 根本原因是双指针没有办法遍历所有==k的情况,所以其实就是假的算法

// Race #95分未优化双指针多遍骗分法

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#define int long long
const int N = 2e5;
int head[N + 2], nxt[N * 2 + 2], ver[N * 2 + 2], tot, edge[N * 2 + 2];
int ans, v[N + 2], sz[N + 2], mx[N + 2], S, root, d[N + 2], bel[N + 2], dep[N + 2];
int n, m;
std::vector<int> rem;

void add(int x, int y, int z) {
    tot++;
    edge[tot] = z;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void find(int x, int fa) {
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa || v[y]) continue;
        find(y, x);
        sz[x] += sz[y];
        mx[x] = std::max(mx[x], sz[y]);
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;  // 这里感觉没必要,按理说找到一个就可以返回了,可能是和第一次便利有关吧
}

void getdis(int x, int fa, int dis, int from) {
    rem.push_back(x);
    dep[x] = dep[fa] + 1;
    d[x] = dis;
    bel[x] = from;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (y == fa || v[y]) continue;
        getdis(y, x, dis + z, from);
    }
}

bool cmp(int x, int y) { return d[x] < d[y]; }
void calc(int x) {
    rem.clear();
    rem.push_back(x);
    d[x] = 0;
    bel[x] = x;  // belong
    dep[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y]) continue;
        getdis(y, x, z, y);
    }
    std::sort(rem.begin(), rem.end(), cmp);  // 按照深度排序

    int l = 0, r = rem.size() - 1;
    // std::cout << x << ": " << rem.size() << std::endl;
    while (l < r) {
        if (d[rem[l]] + d[rem[r]] > m)
            r--;
        else if (d[rem[l]] + d[rem[r]] < m)
            l++;
        else if (bel[rem[l]] == bel[rem[r]]) {
            if (d[rem[r]] == d[rem[r - 1]])
                r--;
            else
                l++;
        } else {
            ans = std::min(ans, dep[rem[l]] + dep[rem[r]]);
            if (d[rem[r]] == d[rem[r - 1]])
                r--;
            else
                l++;
            // break;
        }
    }
    l = 0, r = rem.size() - 1;
    while (l < r) {
        if (d[rem[l]] + d[rem[r]] > m)
            r--;
        else if (d[rem[l]] + d[rem[r]] < m)
            l++;
        else if (bel[rem[l]] == bel[rem[r]]) {
            if (d[rem[l]] == d[rem[l + 1]])
                l++;
            else
                r--;
        } else {
            ans = std::min(ans, dep[rem[l]] + dep[rem[r]]);
            if (d[rem[l]] == d[rem[l + 1]])
                l++;
            else
                r--;
            // break;
        }
    }
}

void solve(int x) {
    v[x] = 1;  // 保证了solve遍历不会出现循环
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);  // 第一次其实就是求sz
        S = sz[y], root = 0, mx[0] = 1e9;
        find(y, x);  // 第二次是找到重心
        solve(root);
    }
}

signed main() {
    // freopen("/Users/chenjiayou/Downloads/P4149_2.in", "r", stdin);
    std::cin >> n >> m;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        std::cin >> a >> b >> c;
        a++, b++;
        add(a, b, c), add(b, a, c);
    }
    ans = 1e15;
    solve(1);
    std::cout << (ans == 1e15 ? -1 : ans) << std::endl;
    return 0;
}
// Tree

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <iostream>
#include <vector>

const int N = 4e4;
int head[N + 2], nxt[N * 2 + 2], edge[N * 2 + 2], ver[N * 2 + 2], tot;
int n, k, ans, v[N + 2], sz[N + 2], S, mx[N + 2], root, tr[N + 2], d[N + 2], tr0;
std::vector<int> rem;

int lowbit(int x) { return x & (-x); }
void plus(int p, int x) {
    // 这里注意不能给0增加,所以单独开一个节点保存0
    if (p == 0) {
        tr0 += x;
        return;
    }
    while (p <= k) {
        tr[p] += x;
        p += lowbit(p);
    }
}
int query(int x) {
    if (x < 0) return 0;
    int res = 0;
    while (x) {
        res += tr[x];
        x -= lowbit(x);
    }
    return res + tr0;  // 询问是非负整数,也就>=0
}

void add(int x, int y, int z) {
    tot++;
    ver[tot] = y;
    edge[tot] = z;
    nxt[tot] = head[x];
    head[x] = tot;
}

void find(int x, int fa) {
    // std::cout << x << " " << fa << std::endl;
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa || v[y]) continue;
        find(y, x);
        sz[x] += sz[y];
        mx[x] = std::max(mx[x], sz[y]);
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;
}

void getdis(int x, int fa) {
    // std::cout << x << " " << fa << std::endl;
    rem.push_back(d[x]);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (y == fa || v[y]) continue;
        d[y] = d[x] + z;
        getdis(y, x);
    }
}

void calc(int x) {
    // std::cout << x << ": " << std::endl;
    std::vector<int> res;
    res.push_back(0);
    plus(0, 1);  // 这里一定要注意在树状数组中特殊处理0的情况
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y]) continue;
        rem.clear();
        d[y] = z;
        getdis(y, x);
        for (auto j : rem) {
            ans += query(k - j);
        }
        for (auto j : rem) {
            res.push_back(j);
            plus(j, 1);
        }
    }
    for (auto i : res) {
        plus(i, -1);
    }
}

void solve(int x) {
    v[x] = 1;
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);
        root = 0, S = sz[y], mx[root] = 1e9;
        find(y, x);
        solve(y);
    }
}

int main() {
    std::cin >> n;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        std::cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    std::cin >> k;
    solve(1);
    std::cout << ans << std::endl;
    return 0;
}
// 消息传递

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

const int N = 1e5;
int head[N + 2], ver[N * 2 + 2], nxt[N * 2 + 2], tot;
int v[N + 2], sz[N + 2], mx[N + 2], S, root, d[N + 2];
int t[N + 2], n, m, ans[N + 2];
std::vector<int> rem;
std::vector<std::pair<int, int> > q[N + 2], pq;

void add(int x, int y) {
    tot++;
    ver[tot] = y;
    nxt[tot] = head[x];
    head[x] = tot;
}

void getdis(int x, int fa) {
    rem.push_back(d[x]);
    for (auto i : q[x]) {
        if (i.first < (d[x] - 1)) continue;
        pq.push_back({i.first - (d[x] - 1), i.second});  // 要的就是更改之后的d
    }
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa || v[y]) continue;
        d[y] = d[x] + 1;
        getdis(y, x);
    }
}

void calc(int x) {
    // std::cout << x << ": " << std::endl;
    std::vector<int> res, yy;
    d[x] = 1;
    // res.push_back(1);
    // t[1] = 1;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        yy.push_back(y);
        rem.clear();
        pq.clear();
        d[y] = d[x] + 1;
        getdis(y, x);
        for (auto j : pq) {
            ans[j.second] += t[j.first + 1];
        }
        for (auto j : rem) {
            t[j] += 1;
            res.push_back(j);
        }
    }
    // std::cout << d[2] << " " << d[6] << std::endl;
    // for (auto i : pq) {
    //     std::cout << i.first << " " << i.second << std::endl;
    // }
    for (auto i : res) t[i] = 0;
    res.clear();
    d[x] = 1;
    res.push_back(1);
    t[1] = 1;
    for (int i = yy.size() - 1; i >= 0; i--) {
        int y = yy[i];
        rem.clear();
        pq.clear();
        d[y] = d[x] + 1;
        getdis(y, x);
        for (auto j : pq) {
            ans[j.second] += t[j.first + 1];
        }
        for (auto j : rem) {
            t[j] += 1;
            res.push_back(j);
        }
    }
    for (auto i : q[x]) {
        // std::cout << i.first << " :b " << i.second << " " << t[i.first + 1] << " " << ans[i.second] << std::endl;
        ans[i.second] += t[i.first + 1];
    }
    for (auto i : res) t[i] = 0;
}

void find(int x, int fa) {
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (y == fa || v[y]) continue;
        find(y, x);
        sz[x] += sz[y];
        mx[x] = std::max(mx[x], sz[y]);
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;
}

void dfs(int x) {
    v[x] = 1;
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);
        S = sz[y], root = 0, mx[0] = 1e9;
        find(y, x);
        dfs(root);
    }
}

void solve() {
    tot = 0;
    std::cin >> n >> m;
    for (int i = 1; i <= n; i++) head[i] = 0, q[i].clear(), v[i] = 0;
    for (int i = 1; i <= m; i++) ans[i] = 0;
    for (int i = 1; i < n; i++) {
        int a, b;
        std::cin >> a >> b;
        add(a, b), add(b, a);
    }
    for (int i = 1; i <= m; i++) {
        int a, b;
        std::cin >> a >> b;
        q[a].push_back({b, i});
    }
    dfs(1);
    for (int i = 1; i <= m; i++) {
        std::cout << ans[i] << std::endl;
    }
}

int main() {
    std::ios::sync_with_stdio(0);
    std::cin.tie(0);
    std::cout.tie(0);
    int times;
    std::cin >> times;
    while (times--) {
        solve();
    }
    return 0;
}

/*
2
4 2
1 2
2 3
3 4
1 1
2 2

6 2
1 2
2 3
2 4
2 5
5 6
2 0
6 0

*/
// Ruri Loves Maschera

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#define int long long
const int N = 1e5;
int n, l, r;
int head[N * 2 + 2], ver[N * 2 + 2], edge[N * 2 + 2], nxt[N * 2 + 2];
int tr[N + 2], v[N + 2], S, root, sz[N + 2], d[N + 2], val[N + 2], mx[N + 2];
int tot, ans[N + 2];
std::vector<std::pair<int, int>> rem;

int lowbit(int x) {
    return x & (-x);
}
void plus(int p, int x) {
    while (p <= n) {
        tr[p] += x;
        p += lowbit(p);
    }
}
int query(int p) {
    int res = 0;
    while (p) {
        res += tr[p];
        p -= lowbit(p);
    }
    return res;
}
inline int sum(int x, int y) {
    return query(y) - query(x - 1);
}

void add(int x, int y, int z) {
    tot++;
    ver[tot] = y;
    edge[tot] = z;
    nxt[tot] = head[x];
    head[x] = tot;
}

void find(int x, int fa) {
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y] || fa == y) continue;
        find(y, x);
        sz[x] += sz[y];
        if (sz[y] > mx[x]) mx[x] = sz[y];
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;
}

void getdis(int x, int fa) {
    rem.push_back({val[x], d[x]});
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y] || y == fa) continue;
        d[y] = d[x] + 1;
        val[y] = std::max(val[x], z);
        getdis(y, x);
    }
}

void bio_count(std::vector<std::pair<int, int>> &p, int &res) {
    // 第一维度用排序(最大值),第二维度用树状数组(长度)
    std::sort(p.begin(), p.end());
    // plus(1, 1);  // 把根节点加入到待数点的序列
    for (int i = 0; i < (int)p.size(); i++) {
        if (p[i].second - 1 > r) continue;
        // std::cout << l - p[i].second << std::endl;
        res += sum(std::max(1LL, l - (p[i].second - 1) + 1), r - (p[i].second - 1) + 1) * p[i].first;  // 这样应该就是不重不漏了,不用把所有当前的第一维都整理完再第二层
        plus(p[i].second, 1);
    }
    for (int i = 0; i < (int)p.size(); i++) {
        plus(p[i].second, -1);
    }
    // plus(1, -1);
}

void calc(int x) {
    // std::cout << x << std::endl;
    std::vector<std::pair<int, int>> res;
    std::vector<int> yy;
    int cntall = 0, ls = 0;
    // 可能需要特殊处理当前节点
    d[x] = 1;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y]) continue;
        yy.push_back(y);
        rem.clear();
        d[y] = 2;
        val[y] = z;
        getdis(y, x);
        ls = 0;
        // std::cout << y << std::endl;
        bio_count(rem, ls);
        cntall += ls;
        for (auto j : rem) {
            res.push_back(j);
            // std::cout << j.first << " " << j.second << std::endl;
        }
        // std::cout << x << " " << y << ": " << cntall << std::endl;
    }
    ls = 0;
    plus(1, 1);
    bio_count(res, ls);
    plus(1, -1);
    // std::cout << x << ": " << cntall << " " << ls << std::endl;
    ans[x] += (ls - cntall) * 2;
}

void dfs(int x) {
    // std::cout << x << std::endl;
    v[x] = 1;
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);
        S = sz[y], root = 0, mx[0] = 1e9;
        find(y, x);
        dfs(root);
    }
}

void solve() {
    std::cin >> n >> l >> r;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        std::cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    dfs(1);
    int anss = 0;
    for (int i = 1; i <= n; i++) {
        anss += ans[i];
        // std::cout << ans[i] << " \n"[i == n];
    }
    std::cout << anss << std::endl;
}

signed main() {
    std::ios::sync_with_stdio(0);
    std::cin.tie(0), std::cout.tie(0);
    int times = 1;
    // std::cin>>times;
    while (times--) {
        solve();
    }
    return 0;
}

/*
5 2 3
1 2 2
2 3 2
3 4 4
4 5 5

*/
// 聪聪与可可

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>

#define int long long
const int N = 2e4;
int head[N + 2], nxt[N * 2 + 2], ver[N * 2 + 2], tot, edge[N * 2 + 2];
int v[N + 2], d[N + 2], S, sz[N + 2], root, mx[N + 2];
std::vector<int> rem;
int cnt[4], ans;

void add(int x, int y, int z) {
    tot++;
    ver[tot] = y;
    edge[tot] = z;
    nxt[tot] = head[x];
    head[x] = tot;
}

void find(int x, int fa) {
    sz[x] = 1, mx[x] = 0;
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y] || y == fa) continue;
        find(y, x);
        sz[x] += sz[y];
        mx[x] = std::max(mx[x], sz[y]);
    }
    mx[x] = std::max(mx[x], S - sz[x]);
    if (mx[x] < mx[root]) root = x;
}

void deal(std::vector<int> &q, int &result) {
    memset(cnt, 0, sizeof cnt);
    for (auto i : q) {
        int ls = i % 3;
        if (ls != 0) ls = 3 - ls;
        result += cnt[ls];
        cnt[i % 3]++;
    }
    // std::cout << result << "?\n";
}

void getdis(int x, int fa) {
    rem.push_back(d[x]);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y] || y == fa) continue;
        d[y] = d[x] + z;
        getdis(y, x);
    }
}

void calc(int x) {
    int cntall = 0, ls = 0;
    std::vector<int> res;
    res.push_back(0);  // 只有这里计算根节点,子树里面本题目不考虑根节点,因为这样情况在本题目中是合法的,不需要被容斥掉
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        int z = edge[i];
        if (v[y]) continue;
        rem.clear();
        d[y] = z;
        getdis(y, x);
        ls = 0;
        // std::cout << rem.size() << std::endl;
        deal(rem, ls);
        cntall += ls;
        for (auto j : rem) res.push_back(j);
        // std::cout << y << ":: " << cntall << std::endl;
    }
    ls = 0;
    deal(res, ls);
    // std::cout << x << ": " << ls << " " << cntall << std::endl;
    ans += ls - cntall;
}

void dfs(int x) {
    // std::cout << x << std::endl;
    v[x] = 1;
    calc(x);
    for (int i = head[x]; i; i = nxt[i]) {
        int y = ver[i];
        if (v[y]) continue;
        find(y, x);
        S = sz[y], root = 0, mx[0] = 1e9;
        find(y, x);
        dfs(root);
    }
}

int gcd(int x, int y) {
    return y ? gcd(y, x % y) : x;
}

signed main() {
    int n;
    std::cin >> n;
    for (int i = 1; i < n; i++) {
        int a, b, c;
        std::cin >> a >> b >> c;
        add(a, b, c), add(b, a, c);
    }
    dfs(1);
    ans = ans * 2 + n;
    int ls = gcd(ans, n * n);
    ans /= ls;
    std::cout << ans << "/" << (n * n / ls) << std::endl;
    return 0;
}
  • 14
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值