题意:
点分治模板,问树上两点路径权值和为 k k k的无序对数目。
分析:
点分治:每次在无根树中选取一个点为根,然后递归处理以根节点的儿子为根的子树。(选取哪个为根会影响效率) 人话:递归子树
对于以
r
o
o
t
root
root为根的树,那么树上两点
u
,
v
u,v
u,v路径有两种情况:
1.
1.
1.路径经过
r
o
o
t
root
root
2.
2.
2.路径不经过
r
o
o
t
root
root
我们可以只统计路径经过
r
o
o
t
root
root的
u
,
v
u,v
u,v对,而不经过
r
o
o
t
root
root的情况递归子树即可解决。
设
d
i
s
[
i
]
dis[i]
dis[i]是
i
i
i节点到
r
o
o
t
root
root的距离,
b
e
l
o
n
g
[
i
]
=
x
belong[i]=x
belong[i]=x表示
i
i
i节点在以
r
o
o
t
root
root儿子为
x
x
x的子树中。
那么我们需要计算的就是:
d
e
e
p
[
u
]
+
d
e
e
p
[
v
]
=
=
k
deep[u]+deep[v]==k
deep[u]+deep[v]==k的个数减去
d
e
e
p
[
u
]
+
d
e
e
p
[
v
]
=
=
k
,
i
f
(
b
e
l
o
n
g
[
u
]
=
=
b
e
l
o
n
g
[
v
]
)
deep[u]+deep[v]==k,if(belong[u]==belong[v])
deep[u]+deep[v]==k,if(belong[u]==belong[v])的个数,因为我们统计的是经过
r
o
o
t
root
root的路径而在同一棵子树中肯定不会经过
r
o
o
t
root
root,不经过
r
o
o
t
root
root的路径我们递归这个子树,再统计这个子树中路径经过
r
o
o
t
root
root的路径就可以了。
显然,根的选取影响我们递归的效率。
这就有了树的重心:将树的重心删去之后,剩下节点最多的树的节点数最小。
我们选取树的重心为根是可以保证效率
O
(
l
o
g
n
)
O(log n)
O(logn)的,证明可看漆子超论文。
树的重心只需要简单的
d
f
s
dfs
dfs就可求出。
参考论文:分治算法在树的路径问题中的应用
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e4 + 5;
const int maxm = 10 + 5;
const int INF = 0x3f3f3f3f;
const int mod = 1e9 + 7;
int n, m, k, cnt, tot, now, head[maxn];
int rt, son[maxn], dis[maxn], siz[maxn];
int ans[10000000 + 5];
bool vis[maxn];
struct node{
int v, w, nxt;
}e[maxn << 1];
void addedge(int u, int v, int w){
e[++cnt].v = v;
e[cnt].w = w;
e[cnt].nxt = head[u];
head[u] = cnt;
}
void getRoot(int x, int f){
siz[x] = 1, son[x] = 0;
for(int i = head[x]; i; i = e[i].nxt){
int v = e[i].v;
if(v == f || vis[v]) continue;
getRoot(v, x);
siz[x] += siz[v];
son[x] = max(son[x], siz[v]);
}
son[x] = max(son[x], now - siz[x]);
if(son[rt] > son[x] || son[rt] == 0) rt = x;
}
void getDis(int x, int f, int len){
dis[++tot] = len;
for(int i = head[x]; i; i = e[i].nxt){
if(e[i].v == f || vis[e[i].v]) continue;
getDis(e[i].v, x, len + e[i].w);
}
}
void solve(int x, int op, int len){
tot = 0;
getDis(x, 0, len);
for(int i = 1; i <= tot; i++)
for(int j = i + 1; j <= tot; j++)
ans[dis[i] + dis[j]] += op;
}
void dfs(int x){
rt = 0, getRoot(x, 0), x = rt;
vis[x] = 1;
solve(x, 1, 0);
for(int i = head[x]; i; i = e[i].nxt){
int v = e[i].v;
if(vis[v]) continue;
solve(v, -1, e[i].w);
now = siz[v];
dfs(v);
}
}
int main() {
scanf("%d %d", &n, &m);
for(int i = 1, u, v, w; i <= n - 1; i++){
scanf("%d %d %d", &u, &v, &w);
addedge(u, v, w), addedge(v, u, w);
}
now = n;
getRoot(1, 0);
dfs(rt);
while(m--){
scanf("%d", &k);
if(ans[k] > 0) puts("AYE");
else puts("NAY");
}
return 0;
}