POJ1741 Tree 题解
题目大意
给一颗 n n 个节点的树,每条边上有一个距离。定义 d(u,v) d ( u , v ) 为 u u 到的最小距离。给定 k k 值,求有多少点对使 u u 到的距离小于等于 k k 。
解题思路
如果用暴力枚举,那么时间复杂度为
用DFS,时间复杂度为
O(N2)
O
(
N
2
)
暴力膜不可取。
那么我们重新分析,发现
(u,v)
(
u
,
v
)
有两种情况(对于一整棵树)
1.
u
u
,在同一棵子树上
2.
u
u
,在不同子树上
其中1的情况继续分析,会变成2的情况,只要递推下去就可以了。
那么就会想到用点分治。
我们用
Get(x,v)
G
e
t
(
x
,
v
)
来求以x为根节点,其中
d(x,fa[x])=v
d
(
x
,
f
a
[
x
]
)
=
v
中
(u,v)
(
u
,
v
)
数对的个数
inline int Get(int x,int v){
dis[x]=v;Sort[Num=1]=dis[x];
dfs(x,0);int S=0; //dfs求该子树上的节点到根节点的距离
sort(Sort+1,Sort+1+Num);
for (int i=1,j=Num;i<j;) if (Sort[i]+Sort[j]<=K) S+=j-i,i++;else j--;
return S;
}
inline void Solve(int x){
Ans+=Get(x,0);
vis[x]=1;
for (int i=lnk[x];i;i=nxt[i]){
if (vis[son[i]]) continue;
Ans-=Get(son[i],w[i]);
Root=0;Sum=Size[son[i]];
getrt(son[i],0);
Solve(Root);
}
}
为什么要
Ans−=Get(son[i],w[i])
A
n
s
−
=
G
e
t
(
s
o
n
[
i
]
,
w
[
i
]
)
?因为在
Get(1,0)
G
e
t
(
1
,
0
)
中
(3,5)
(
3
,
5
)
这对数已经被算进去了,
Get(3,0)
G
e
t
(
3
,
0
)
中
(3,5)
(
3
,
5
)
又算了一遍,所以要用
Get(son[i],w[i])
G
e
t
(
s
o
n
[
i
]
,
w
[
i
]
)
减一遍,至于为什么
d(son[i],fa[son[i]])=w[i]
d
(
s
o
n
[
i
]
,
f
a
[
s
o
n
[
i
]
]
)
=
w
[
i
]
,那是因为
(u,v)
(
u
,
v
)
数对可能原来计算时没有算进去,当
Get(son[i],0)
G
e
t
(
s
o
n
[
i
]
,
0
)
时会被算进去,就会多减1,所以要加一道保险。
代码
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int N,K,Ans,dis[10005];
int tot,nxt[20005],lnk[10005],son[20005],w[20005];
int Root,Sum,Num,Size[10005],F[10005],Sort[10005];
bool vis[10005];
inline void add(int x,int y,int z){son[++tot]=y;w[tot]=z;nxt[tot]=lnk[x];lnk[x]=tot;} //建边
inline void getrt(int x,int fa){ //求重心
F[x]=0;Size[x]=1;
for (int i=lnk[x];i;i=nxt[i]){
if (vis[son[i]]||son[i]==fa) continue;
getrt(son[i],x);
Size[x]+=Size[son[i]];
F[x]=max(F[x],Size[son[i]]);
}
F[x]=max(F[x],Sum-Size[x]);
if (F[x]<F[Root]) Root=x;
}
inline void dfs(int x,int fa){
for (int i=lnk[x];i;i=nxt[i]){
if (son[i]==fa||vis[son[i]]) continue;
dis[son[i]]=dis[x]+w[i];
Sort[++Num]=dis[son[i]];
dfs(son[i],x);
}
}
inline int Get(int x,int v){
dis[x]=v;Sort[Num=1]=dis[x];
dfs(x,0);int S=0;
sort(Sort+1,Sort+1+Num);
for (int i=1,j=Num;i<j;) if (Sort[i]+Sort[j]<=K) S+=j-i,i++;else j--;
//如果d(i,j)<=k,那么d(i+1,j),d(i+2,j),……,d(j-1,j)<=k,有(j-1-i+1)=j-i对数对
return S;
}
inline void Solve(int x){
Ans+=Get(x,0);
vis[x]=1;
for (int i=lnk[x];i;i=nxt[i]){
if (vis[son[i]]) continue;
Ans-=Get(son[i],w[i]);
Root=0;Sum=Size[son[i]];
getrt(son[i],0);
Solve(Root);
}
}
int main()
{
while (1){
scanf("%d%d",&N,&K);
if (N==K&&N==0) return 0;
tot=0;
memset(vis,0,sizeof vis);
memset(lnk,0,sizeof lnk);
for (int i=1;i<N;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);add(y,x,z);
}
Root=0;Ans=0;
F[0]=2e9;Sum=N;
getrt(1,0);
Solve(Root);
printf("%d\n",Ans);
}
return 0;
}