题面
解法
这道题的点分治依然比较基础
- 将黑色的边变成1,白色的边变成-1,这样比较容易判定。
- 因为要满足路径中间存在一个点使得这个点可以将这条路径分成两段且长度为0,所以这样就变得不是特别容易处理。
- 考虑在枚举分治重心的时候,已经处理完了前面的子树,假设对于当前的子树中的一点 x x x,当前的深度为 d d d,那么前面的子树中一定要有一个深度为 − d -d −d的点,假设为 y y y。同时, x − y x-y x−y这条路径合法当且仅当 x x x到重心存在一个点 z z z使得 d [ z ] = d [ y ] d[z]=d[y] d[z]=d[y]或 y y y到重心存在一点 z z z使得 d [ z ] = d [ x ] d[z]=d[x] d[z]=d[x]
- 那么我们可以记 s [ i ] [ 0 / 1 ] s[i][0/1] s[i][0/1]表示前面的子树中深度为 i i i,不存在/存在它到当前分治重心路径上的一个点,使得它的深度 = i =i =i的点的个数。
- 考虑对于当前深度为 d d d的点 x x x,显然至少有 s [ − d ] [ 1 ] s[-d][1] s[−d][1]个点会和 x x x形成一条合法的路径。然后如果分治重心到 x x x的路径上存在一个深度为 − d -d −d的点,那么还会有 s [ − d ] [ 0 ] s[-d][0] s[−d][0]个点会和 x x x形成一条合法路径。
- 然后对于重心到当前点是否存在深度为 i i i的点直接在递归的时候用一个数组处理一下就可以了。
- 时间复杂度: O ( n log n ) O(n\log n) O(nlogn)
【注意事项】
- 需要考虑分治重心到当前点也可能存在一条合法路径的情况。
- 因为深度会出现负数,所以要 + n +n +n使得在数组中访问的是一个非负数。
代码
#include <bits/stdc++.h>
#define inf 1 << 30
#define N 100010
using namespace std;
template <typename T> void chkmax(T &x, T y) {x = x > y ? x : y;}
template <typename T> void read(T &x) {
x = 0; int f = 1; char c = getchar();
while (!isdigit(c)) {if (c == '-') f = -1; c = getchar();}
while (isdigit(c)) x = x * 10 + c - '0', c = getchar(); x *= f;
}
struct Edge {int next, num, v;} e[N * 3];
int n, rt, now, cnt, f[N], d[N], siz[N], vis[N], used[N * 2];
long long ans, s[N * 2][2];
void add(int x, int y, int v) {
e[++cnt] = (Edge) {e[x].next, y, v};
e[x].next = cnt;
}
void getr(int x, int fa) {
f[x] = 0, siz[x] = 1;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num, v = e[p].v;
if (k == fa || vis[k]) continue;
getr(k, x); siz[x] += siz[k];
chkmax(f[x], siz[k]);
}
chkmax(f[x], now - siz[x]);
if (f[x] < f[rt]) rt = x;
}
void calc(int x, int fa, int dep) {
d[x] = dep; ans += s[n - dep][1];
if (used[n + dep]) ans += s[n - dep][0];
if (dep == 0) ans += used[n] > 1;
used[n + dep]++;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num, v = e[p].v;
if (k == fa || vis[k]) continue;
calc(k, x, dep + v);
}
used[n + dep]--;
}
void update(int x, int fa, int dep) {
s[n + dep][used[n + dep] > 0]++;
used[n + dep]++;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num, v = e[p].v;
if (k == fa || vis[k]) continue;
update(k, x, dep + v);
}
used[n + dep]--;
}
void Clear(int x, int fa) {
s[d[x] + n][0] = s[d[x] + n][1] = 0;
for (int p = e[x].next; p; p = e[p].next)
if (!vis[e[p].num] && e[p].num != fa) Clear(e[p].num, x);
}
void work(int x) {
vis[x] = 1, used[n] = 1;
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num, v = e[p].v;
if (vis[k]) continue;
calc(k, x, v), update(k, x, v);
}
for (int p = e[x].next; p; p = e[p].next)
if (!vis[e[p].num]) Clear(e[p].num, x);
for (int p = e[x].next; p; p = e[p].next) {
int k = e[p].num;
if (vis[k]) continue;
f[rt = 0] = inf, now = siz[k];
getr(k, x), work(rt);
}
}
int main() {
read(n); cnt = n;
for (int i = 1; i < n; i++) {
int x, y, v;
read(x), read(y), read(v);
if (!v) v = -1;
add(x, y, v), add(y, x, v);
}
f[rt = 0] = inf, now = n;
getr(1, 0); work(rt);
cout << ans << "\n";
return 0;
}