——点分治
原题传送门
前言
这里埋着一个 T r e a p Treap Treap的灵魂
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Data
Input
The input contains several test cases. The first line of each test case contains two integers n, k. The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
n<=10000
前言
去参加
Z
J
J
H
ZJJH
ZJJH 的神奇集训,发现神奇奆老,被虐了一批…
这是
D
a
y
3
Day3
Day3的一道开局入门板子题异常难的题目.
思路
点分治的模板题.
何为"点分治"?就是点的分治.
比如有一棵树:
那么我们如果要求这个问题:
∑
s
,
t
(
d
i
s
(
s
,
t
)
⩽
k
)
\sum_{s,t} (dis(s,t) \leqslant k)
s,t∑(dis(s,t)⩽k)
显然,如果我们变成这个形式,就很happy一些:
∑
s
,
t
(
d
i
s
(
s
,
t
)
⩽
k
&
L
C
A
(
s
,
t
)
=
r
o
o
t
)
\sum_{s,t} (dis(s,t) \leqslant k\ \&\ LCA(s,t)=root )
s,t∑(dis(s,t)⩽k & LCA(s,t)=root)
也就是说
s
−
t
s-t
s−t路径经过根节点.
那么直接枚举每个子树即可…(这个比较trival了吧)
然后,我们发现其它子树也有这个性质.
所以,很自然地想到点分治.
点分治
把所有情况分成两种情况:
- 过根结点的情况
- 在根结点各个子树内的情况.
注意发现 (2)是包含在(1)中 的.
我们往各个子树递归的过程中,会处理(2)的情况.
所以我们致力于处理(1)的情况,
步骤如下:
- 计算各结点到根结点的距离.
- 然后遍历所有结点,寻找
d
i
s
[
s
]
+
d
i
s
[
t
]
⩽
k
dis[s]+dis[t] \leqslant k
dis[s]+dis[t]⩽k就可以了…
但是这里有一点需要注意.
我们比较难以判断 L C A ( s , t ) = r o o t LCA(s,t)=root LCA(s,t)=root的情况. - 怎么办?容斥!
我们先不管 L C A ( s , t ) = r o o t LCA(s,t)=root LCA(s,t)=root的限制,计算所有 d i s ( s , t ) ⩽ k dis(s,t)\leqslant k dis(s,t)⩽k的情况. *1
减去 L C A ( s , t ) ≠ r o o t LCA(s,t)\neq root LCA(s,t)=root的情况. *2
然后我们发现,由于递归,我们只需**减去下一层子树的(*1)**即可.
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
const int N=1e4+5;
const int INF=0x3f3f3f3f;
int head[N],nxt[2*N],to[2*N],val[2*N];
int sz[N],a[N],dis[N],d[N];
bool vis[N];
int n,K,cnt,R,ans,V;
void init()
{
cnt=0;
memset(head,0,sizeof(head));
memset(nxt,0,sizeof(nxt));
memset(to,0,sizeof(to));
memset(val,0,sizeof(val));
memset(vis,false,sizeof(vis));
}
void addedge(int u,int v,int w)
{
nxt[++cnt]=head[u];head[u]=cnt;to[cnt]=v;val[cnt]=w;
}
void findroot(int u,int f)
{
sz[u]=1,a[u]=0;
for (int i=head[u];i;i=nxt[i])
{
int v=to[i];
if (v==f||vis[v]) continue;
findroot(v,u);
sz[u]+=sz[v];
a[u]=max(sz[v],a[u]);
}
a[u]=max(V-sz[u],a[u]);
if (a[u]<a[R]) R=u;
}
int getdis(int u,int f)
{
dis[++cnt]=d[u];
for (int i=head[u];i;i=nxt[i])
{
int v=to[i];
if (v==f||vis[v]) continue;
d[v]=d[u]+val[i];
getdis(v,u);
}
}
int calc(int u,int tag)
{
d[u]=tag;
cnt=0;
getdis(u,0);
sort(dis+1,dis+cnt+1);
int l=1,r=cnt,k=0;
while (l<r)
{
if (dis[l]+dis[r]<=K)
{
k+=r-l;
l++;
}
else r--;
}
return k;
}
void solve(int u)
{
ans+=calc(u,0);
vis[u]=true;
for (int i=head[u];i;i=nxt[i])
{
int v=to[i];
if (vis[v]) continue;
ans-=calc(v,val[i]);
V=sz[v];R=0;
a[0]=INF;
findroot(v,0);
solve(R);
}
}
int main()
{
while (~scanf("%d%d",&n,&K))
{
if (n==0||K==0) break;
init();
for (int i=1;i<n;i++)
{
int a,b,w;
scanf("%d%d%d",&a,&b,&w);
addedge(a,b,w),addedge(b,a,w);
}
R=0;a[0]=INF;
V=n;
findroot(1,0);
ans=0;
solve(R);
printf("%d\n",ans);
}
return 0;
}
感谢奆老关注 qwq ?