对树的节点进行递归求解:求出当前连通块的重心 $rt$,计算经过 $rt$ 的链对答案的贡献,然后删去 $rt$,向其各子树递归。
每次取重心作为根是为了防止树退化成链,从而保证递归层数为 $O(\log n)$。
Luogu P3806 点分治模板
#include <bits/stdc++.h>
using namespace std;
// Capacity of every vertex-indexed array: 10^4 vertices plus a small margin.
constexpr int maxn = 10000 + 10;
// Reads the next decimal integer from stdin.
// Any non-digit characters before the number are skipped; if a '-' appears
// among them the result is negated. Returns 0 when end of input is reached
// before any digit (previously this spun forever, because storing getchar()'s
// result in a char made EOF indistinguishable from an ordinary byte).
inline int read () {
    int x = 0, f = 1;
    int c = getchar();  // must be int so that EOF stays distinguishable
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    return x * f;
}
// n: vertex count, m: query count.
int n, m;
// sum: size assumed for the component get_rt is scanning; rt: best centroid found so far.
int sum, rt;
// Forward-star adjacency list; tot starts at 1, so the first edge gets index 2.
int head[maxn];
int to[maxn * 2], cost[maxn * 2], nxt[maxn * 2];
int tot = 1;
// sz: DFS subtree sizes, f: largest remaining piece, dis: distance to current centroid,
// rem: collected distances (rem[0] is the count), q: judge entries to clear.
int sz[maxn], f[maxn], dis[maxn], rem[maxn], q[maxn];
// Query lengths (1-based) and whether a path of that length exists.
int query[110], ans[110];
// vis: vertex already used as a centroid; judge[d]: some earlier path from the centroid has length d.
bool vis[maxn];
bool judge[10000010];
// Inserts the directed edge u -> v with weight c at the front of u's edge list.
// head[u] holds u's newest edge index and nxt[] links to the previous one.
void addEdge (int u, int v, int c) {
    const int id = ++tot;
    to[id] = v;
    cost[id] = c;
    nxt[id] = head[u];
    head[u] = id;
}
// Post-order DFS over the unvisited vertices reachable from v (entered from fa).
// It fills sz[v] with the size of v's DFS subtree. It fills f[v] with the largest
// piece that would remain if v were removed; the part "above" v counts as sum - sz[v].
// The vertex with the strictly smallest f seen so far is stored in rt.
void get_rt (int v, int fa) {
    int heaviest = 0;
    sz[v] = 1;
    for (int e = head[v]; e != 0; e = nxt[e]) {
        const int u = to[e];
        if (u != fa && !vis[u]) {
            get_rt (u, v);
            sz[v] += sz[u];
            if (sz[u] > heaviest) heaviest = sz[u];
        }
    }
    f[v] = max (heaviest, sum - sz[v]);
    if (f[v] < f[rt]) rt = v;
}
// Pre-order DFS from v (entered from fa) over unvisited vertices.
// Appends dis[v], then the distance of every vertex below it, to rem.
// rem[0] serves as the element count.
// The caller must set dis[v] before the call; children derive theirs from it.
void get_dis (int v, int fa) {
    rem[0] += 1;
    rem[rem[0]] = dis[v];
    for (int e = head[v]; e != 0; e = nxt[e]) {
        const int u = to[e];
        if (u == fa || vis[u]) continue;
        dis[u] = dis[v] + cost[e];
        get_dis (u, v);
    }
}
// Answers the queries for all paths that pass through the centroid v.
// Subtrees of v are handled one at a time:
//   1. get_dis collects every distance d from v into the subtree (rem[1..rem[0]]).
//   2. For each query k, the query is satisfied if query[k] - d was produced
//      by an earlier subtree. judge[0] is set by solve, which covers paths
//      that end at v itself.
//   3. Only then are this subtree's distances added to judge, so two endpoints
//      in the same subtree are never paired.
// Every judge entry set here is recorded in q and cleared before returning.
void calc (int v) {
    // Number of slots in judge. Path lengths can reach (n - 1) * max edge
    // weight, which may exceed this, so distances are range-checked before
    // being used as an index (previously they were written past the end).
    const int judge_size = static_cast<int>(sizeof(judge) / sizeof(judge[0]));
    int p = 0;  // number of distances recorded in q
    for (int i = head[v]; i; i = nxt[i]) {
        if (vis[to[i]]) continue;
        rem[0] = 0;
        dis[to[i]] = cost[i];
        get_dis (to[i], v);
        // Pair this subtree's distances with those of earlier subtrees.
        for (int j = rem[0]; j; j --)
            for (int k = 1; k <= m; k ++)
                if (rem[j] <= query[k] && query[k] - rem[j] < judge_size)
                    ans[k] |= judge[query[k] - rem[j]];
        // Publish this subtree's distances for later subtrees. A distance
        // >= judge_size is only ever looked up as query[k] - rem[j] <= query[k],
        // so skipping it loses nothing while queries stay below judge_size.
        for (int j = rem[0]; j; j --)
            if (rem[j] < judge_size) q[++ p] = rem[j], judge[rem[j]] = 1;
    }
    for (int i = 1; i <= p; i ++) judge[q[i]] = 0;
}
// Divide-and-conquer driver. It marks the centroid v as removed and counts
// every path through it via calc. It then finds a centroid for each remaining
// neighbouring component and recurses into it.
void solve (int v) {
    judge[0] = 1;  // distance 0: a path whose other endpoint is v itself
    vis[v] = 1;
    calc (v);
    for (int e = head[v]; e; e = nxt[e]) {
        const int u = to[e];
        if (vis[u]) continue;
        // sz[u] is reused from the get_rt pass that selected v. It is the exact
        // component size when u was below v in that DFS, and only an
        // approximation when u was v's parent there.
        sum = sz[u];
        rt = 0;
        f[0] = n;  // sentinel: any real vertex beats it
        get_rt (u, 0);
        solve (rt);
    }
}
// Input: n, m, then n - 1 weighted edges "u v c", then m query lengths.
// Output: one line per query, "AYE" if some path has exactly that length,
// otherwise "NAY".
int main() {
    n = read();
    m = read();
    for (int e = 1; e < n; ++e) {
        const int a = read();
        const int b = read();
        const int w = read();
        addEdge (a, b, w);
        addEdge (b, a, w);
    }
    for (int k = 1; k <= m; ++k) query[k] = read();
    // rt is still 0 here, so f[0] = n acts as the "no centroid yet" sentinel.
    sum = n;
    f[rt] = n;
    get_rt (1, 0);
    solve (rt);
    for (int k = 1; k <= m; ++k) printf("%s\n", ans[k] ? "AYE" : "NAY");
    return 0;
}