Description
A tree is a connected graph that doesn’t contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
Input
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
The next n - 1 lines describe the edges as “a_i b_i” (without the quotes) (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i), where a_i and b_i are the vertices connected by the i-th edge. All given edges are different.
Output
Print a single integer — the number of distinct pairs of the tree’s vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Examples
Input
5 2
1 2
2 3
3 4
2 5
Output
4
Input
5 3
1 2
2 3
3 4
4 5
Output
2
Note
In the first sample, the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
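To check the solution below on small inputs, a brute-force reference helps: run a BFS from every vertex, count the vertices found at distance exactly k, and halve the total because every unordered pair is seen from both ends. This is a minimal sketch that runs in O(n^2), so it is only meant for small n; the function name bruteForcePairs and its edge-list signature are hypothetical and not part of the original solution.

```cpp
#include <algorithm>
#include <cassert>
#include <queue>
#include <utility>
#include <vector>

// Hypothetical reference checker, O(n^2): BFS from every vertex and count
// the vertices reached at distance exactly k. Vertices are 1-based.
long long bruteForcePairs(int n, int k,
                          const std::vector<std::pair<int, int>>& edges) {
    assert(n >= 1 && k >= 1);
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& ed : edges) {
        adj[ed.first].push_back(ed.second);
        adj[ed.second].push_back(ed.first);
    }
    long long ordered = 0;  // counts (s, t) and (t, s) separately
    std::vector<int> dist(n + 1);
    std::queue<int> q;
    for (int s = 1; s <= n; ++s) {
        std::fill(dist.begin(), dist.end(), -1);
        dist[s] = 0;
        q.push(s);
        while (!q.empty()) {
            int u = q.front();
            q.pop();
            if (dist[u] == k) {  // found a partner; no need to go deeper
                ++ordered;
                continue;
            }
            for (int v : adj[u]) {
                if (dist[v] == -1) {
                    dist[v] = dist[u] + 1;
                    q.push(v);
                }
            }
        }
    }
    return ordered / 2;  // every unordered pair was found from both ends
}
```

On the first sample, bruteForcePairs(5, 2, {{1, 2}, {2, 3}, {3, 4}, {2, 5}}) returns 4, and on the second sample (k = 3 on the chain 1-2-3-4-5) it returns 2, matching the expected outputs.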
Problem summary: count how many pairs of vertices in the tree are at distance exactly k.
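Solution: centroid decomposition. Pick a centroid c. Every path either passes through c, in which case it is counted at c, or lies entirely inside one child subtree of c, in which case it is counted when that subtree is decomposed recursively. At c, cnt[i] holds the number of vertices at depth i (distance i from c) in the child subtrees processed so far, with c itself counted in cnt[0], and tmp[i] holds the same count for the child subtree currently being walked.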
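Each round starts by finding a centroid: a vertex whose removal leaves no component with more than half of the vertices, which keeps the recursion depth at O(log n). The sketch below assumes a whole tree on vertices 1..n given as an adjacency list, and findCentroid is a hypothetical name. It uses the same criterion as getroot in the code below (minimize the largest remaining piece), but with an explicit stack instead of recursion, a swapped-in choice that avoids deep recursion on long chains.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: returns a centroid of the tree on vertices 1..n,
// i.e. a vertex minimizing the largest component left after removing it
// (that component never has more than n / 2 vertices).
// adj[u] lists the neighbours of u; index 0 is unused.
int findCentroid(int n, const std::vector<std::vector<int>>& adj) {
    assert(n >= 1 && (int)adj.size() > n);
    std::vector<int> parent(n + 1, 0), sub(n + 1, 1), order;
    order.reserve(n);
    // Iterative DFS from vertex 1 records a preorder and each vertex's parent.
    std::vector<int> stack = {1};
    parent[1] = -1;
    while (!stack.empty()) {
        int u = stack.back();
        stack.pop_back();
        order.push_back(u);
        for (int v : adj[u]) {
            if (v != parent[u]) {
                parent[v] = u;
                stack.push_back(v);
            }
        }
    }
    int best = order[0], bestPiece = n;
    // Reverse preorder visits every vertex after all of its descendants,
    // so sub[u] is already the full subtree size when u is reached.
    for (int i = n - 1; i >= 0; --i) {
        int u = order[i];
        int largest = n - sub[u];  // the piece "above" u
        for (int v : adj[u]) {
            if (v != parent[u]) largest = std::max(largest, sub[v]);
        }
        if (largest < bestPiece) {
            bestPiece = largest;
            best = u;
        }
        if (parent[u] > 0) sub[parent[u]] += sub[u];
    }
    return best;
}
```

On the first sample tree it returns 2 (removing 2 leaves pieces of sizes 1, 2 and 1), and on the chain 1-2-3-4-5 it returns 3.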
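Why not use a single cnt array?
To avoid pairing two vertices that lie in the same child subtree (the same branch) of c. Their real path does not pass through c, so adding their depths gives the wrong distance: a child v of c at depth 1 and v's own child at depth 2 would be counted as being at distance 3, although they are adjacent. Collecting each subtree in tmp and merging tmp into cnt only after the whole subtree has been walked guarantees that every vertex is paired only with c or with vertices from earlier subtrees.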
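The effect of the buffering can be checked in isolation. The sketch below is a hypothetical helper, pairsThroughCentroid, that only sees the depths of the vertices in each child subtree of a centroid; with buffered set it follows the cnt/tmp scheme, and without it every depth goes straight into cnt, which is the single-array variant.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Hypothetical helper: counts pairs at distance k whose path goes through a
// centroid c. depthsBySubtree[j] lists the depths (distance from c, all >= 1)
// of the vertices in the j-th child subtree of c; c itself has depth 0.
// buffered = true follows the cnt/tmp scheme; buffered = false writes every
// depth straight into cnt, i.e. the single-array variant.
long long pairsThroughCentroid(const std::vector<std::vector<int>>& depthsBySubtree,
                               int k, bool buffered) {
    assert(k >= 1);
    std::vector<long long> cnt(k + 1, 0), tmp(k + 1, 0);
    cnt[0] = 1;  // the centroid itself
    long long ans = 0;
    for (const auto& depths : depthsBySubtree) {
        std::fill(tmp.begin(), tmp.end(), 0);
        for (int d : depths) {
            if (d > k) continue;  // too deep to have a partner
            ans += cnt[k - d];    // partners: c or vertices of earlier subtrees
            if (buffered) ++tmp[d];
            else ++cnt[d];        // leaks into later vertices of the same subtree
        }
        for (int d = 1; d <= k; ++d) cnt[d] += tmp[d];  // merge after the subtree
    }
    return ans;
}
```

For a centroid whose only child subtree is a chain with depths {1, 2} and k = 3, the buffered version returns 0 and the single-array version returns 1, the spurious same-branch pair. Adding a second subtree with one vertex at depth 1 gives 1 (correct) versus 2.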
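#include <cstdio>
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 5e4 + 7;
// Adjacency list kept as linked lists of edges; each undirected edge is stored twice.
struct Edge{
    int nxt, to;
}e[N<<1];
int head[N], tot=0;
// vis[u]: u has already been used as a centroid (removed from the tree).
// son[u]: size of u's subtree in the current getroot pass.
// f[u]:   size of the largest piece left if u were removed.
// cnt[d]: vertices at depth d in the subtrees of the centroid already processed.
// tmp[d]: vertices at depth d in the subtree currently being processed.
int vis[N], son[N], f[N], cnt[505], tmp[505];
// ans is at most n(n-1)/2, about 1.25e9 for n = 50000, which fits in a 32-bit int.
int sum, ans, k, n, root;
void addeage(int u, int v){
    e[++tot].nxt=head[u], e[tot].to=v;
    head[u]=tot;
}
// Computes subtree sizes in the current component (sum vertices) and keeps in
// root the vertex whose largest remaining piece is smallest, i.e. the centroid.
void getroot(int u, int fa){
    son[u]=1, f[u]=0;
    for ( int i=head[u]; i; i=e[i].nxt ){
        int v=e[i].to;
        if( v==fa || vis[v] ) continue;
        getroot(v, u);
        son[u]+=son[v];
        f[u]=max(f[u], son[v]);
    }
    f[u]=max(f[u], sum-son[u]);   // the piece "above" u
    if(f[u]<f[root]) root=u;
}
// Walks one child subtree of the centroid. A vertex at depth dis is paired with
// the cnt[k-dis] vertices at depth k-dis seen so far (earlier subtrees or the centroid).
void getdis(int u, int fa, int dis){
    if( dis>k ) return;   // deeper vertices are even farther away
    ans+=cnt[k-dis];
    ++tmp[dis];
    for ( int i=head[u]; i; i=e[i].nxt ){
        int v=e[i].to;
        if( v==fa || vis[v] ) continue;
        getdis(v,u,dis+1);
    }
}
// Counts the pairs whose path passes through the centroid u.
void cal(int u){
    for ( int i=0; i<=k; i++ ) cnt[i]=0;
    cnt[0]=1;   // u itself, at depth 0
    for ( int i=head[u]; i; i=e[i].nxt ){
        int v=e[i].to;
        if( vis[v] ) continue;
        for ( int j=0; j<=k; j++ ) tmp[j]=0;
        getdis(v,v,1);   // fa=v is harmless: u is skipped because vis[u]=1
        // Merge only after the whole subtree is done (see the note above).
        for ( int j=1; j<=k; j++ ) cnt[j]+=tmp[j];
    }
}
void solve(int u){
    vis[u]=1;   // remove u from the tree
    cal(u);
    for ( int i=head[u]; i; i=e[i].nxt ){
        int v=e[i].to;
        if( vis[v] ) continue;
        // son[v] is left over from the previous getroot pass and may differ from
        // the true size of v's component; this only affects which vertex is
        // chosen as the next root, not the count.
        root=0, f[0]=sum=son[v];
        getroot(v,-1);
        solve(root);
    }
}
int main(){
    scanf("%d%d", &n, &k );
    for ( int i=1; i<n; i++ ){
        int x, y;
        scanf("%d%d", &x, &y);
        addeage(x,y), addeage(y,x);
    }
    f[0]=n, sum=n, root=0, ans=0;   // f[0]=n is larger than any real f[u]
    getroot(1,-1);
    solve(root);
    printf("%d\n", ans );
    return 0;
}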