Problem: http://codeforces.com/problemset/problem/161/D
D. Distance in Tree
time limit per test: 3 seconds
memory limit per test: 512 megabytes
input: standard input
output: standard output
A tree is a connected graph that doesn’t contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
Input
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Next n - 1 lines describe the edges as “ai bi” (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.
Output
Print a single integer — the number of distinct pairs of the tree’s vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Examples
Input
5 2
1 2
2 3
3 4
2 5
Output
4
Input
5 3
1 2
2 3
3 4
4 5
Output
2
Note
In the first sample the pairs of vertexes at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
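Problem summary:
Given a tree, count the pairs of vertices whose distance is exactly k.
Approach: the same as the previous problem (centroid decomposition, as in the code), except that we count the pairs whose distance is exactly k.
In the cal function we introduce tmp[] and cnt[]. While the subtree of one child of x is being processed, cnt[d] holds the number of vertices at distance d (d <= k) from x in the subtrees already processed at the current level of solve, and tmp[d] collects the same counts for the subtree being processed. Since k <= 500, both arrays are small. The counting at each level takes O(n), and no pair is counted twice, because a vertex is only matched against x itself and against vertices from earlier subtrees.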
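The counting step at a single root can be sketched on its own. This is a minimal illustration, not the author's code: the helper name countThroughRoot and the branches input (one list of distances per child subtree) are assumptions made for the example, and it uses two passes over each branch instead of a separate tmp[] array.

```cpp
#include <vector>

// Hypothetical helper: branches[i] lists the distance from the root to every
// vertex in the root's i-th child subtree. Returns the number of pairs at
// distance exactly k whose path passes through the root.
long long countThroughRoot(const std::vector<std::vector<int>>& branches, int k)
{
    std::vector<long long> cnt(k + 1, 0);
    cnt[0] = 1;  // the root itself, at distance 0
    long long pairs = 0;
    for (const auto& branch : branches)
    {
        // Match this branch only against the root and earlier branches,
        // so two vertices of the same branch are never paired here.
        for (int d : branch)
            if (d <= k) pairs += cnt[k - d];
        // Only afterwards make this branch visible to later branches.
        for (int d : branch)
            if (d <= k) ++cnt[d];
    }
    return pairs;
}
```

For the first sample with vertex 2 as the root and k = 2, the branches are {1}, {1, 2} and {1}, and the function returns 4.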
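Code:
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <vector>
#define N 50005
using namespace std;

int n, K, rt, sz[N];
int cnt[505], tmp[505];  // cnt: distances seen in earlier subtrees; tmp: distances in the current subtree
vector<int> lin[N];      // adjacency lists
int vis[N], f[N];        // vis: vertex already used as a centroid; f: largest part left after removing the vertex
int total;               // size of the component being decomposed
long long ans;

// Find the centroid of the component containing x; the result is left in rt.
void getrt(int x, int fa)
{
    sz[x] = 1;
    f[x] = 0;
    for (size_t i = 0; i < lin[x].size(); i++)
    {
        int u = lin[x][i];
        if (u == fa || vis[u]) continue;
        getrt(u, x);
        sz[x] += sz[u];
        f[x] = max(f[x], sz[u]);
    }
    f[x] = max(f[x], total - sz[x]);  // the part of the component above x
    if (f[rt] > f[x]) rt = x;
}

// Walk one child subtree: match each vertex against earlier subtrees via cnt[],
// record its distance in tmp[], and recompute sz[] for the next level of solve.
void getdis(int x, int fa, int dis)
{
    sz[x] = 1;
    if (dis <= K)
    {
        ans += cnt[K - dis];
        ++tmp[dis];
    }
    for (size_t i = 0; i < lin[x].size(); i++)
    {
        int u = lin[x][i];
        if (vis[u] || u == fa) continue;
        getdis(u, x, dis + 1);
        sz[x] += sz[u];
    }
}

// Count the pairs at distance exactly K whose path passes through x.
void cal(int x)
{
    memset(cnt, 0, sizeof(cnt));
    cnt[0] = 1;  // x itself, at distance 0
    for (size_t i = 0; i < lin[x].size(); i++)
    {
        int u = lin[x][i];
        if (vis[u]) continue;
        memset(tmp, 0, sizeof(tmp));
        getdis(u, x, 1);
        for (int j = 1; j <= K; j++)  // merge only after the whole subtree has been matched
            cnt[j] += tmp[j];
    }
}

// Centroid decomposition: handle paths through x, then recurse into each remaining component.
void solve(int x)
{
    vis[x] = 1;
    cal(x);
    for (size_t i = 0; i < lin[x].size(); i++)
    {
        int u = lin[x][i];
        if (vis[u]) continue;
        f[0] = total = sz[u];  // f[0] is a sentinel larger than any real f value
        getrt(u, rt = 0);
        solve(rt);
    }
}

int main()
{
    int aa, bb;
    scanf("%d%d", &n, &K);
    for (int i = 1; i < n; i++)
    {
        scanf("%d%d", &aa, &bb);
        lin[aa].push_back(bb);
        lin[bb].push_back(aa);
    }
    f[0] = total = n;
    getrt(1, rt = 0);
    solve(rt);
    cout << ans << endl;
    return 0;
}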
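
To cross-check the solution on small inputs, a brute force that runs a BFS from every vertex can serve as a reference. This is only a sketch for testing: the function name bruteForcePairs and its edge-list parameter are made up for this example, and at O(n^2) it is not meant for the full limits.

```cpp
#include <queue>
#include <utility>
#include <vector>

// Hypothetical checker: counts unordered pairs of vertices at distance exactly k
// by running a BFS from every vertex. Vertices are numbered 1..n and edges holds
// the n - 1 tree edges. Assumes k >= 1, as in the statement.
long long bruteForcePairs(int n, int k, const std::vector<std::pair<int, int>>& edges)
{
    std::vector<std::vector<int>> adj(n + 1);
    for (const auto& e : edges)
    {
        adj[e.first].push_back(e.second);
        adj[e.second].push_back(e.first);
    }
    long long ordered = 0;  // every pair is found twice, once from each endpoint
    for (int s = 1; s <= n; s++)
    {
        std::vector<int> dist(n + 1, -1);
        std::queue<int> q;
        dist[s] = 0;
        q.push(s);
        while (!q.empty())
        {
            int v = q.front();
            q.pop();
            if (dist[v] == k)
            {
                ++ordered;
                continue;  // vertices beyond distance k are not needed
            }
            for (int u : adj[v])
                if (dist[u] == -1)
                {
                    dist[u] = dist[v] + 1;
                    q.push(u);
                }
        }
    }
    return ordered / 2;
}
```

On the two samples from the statement it returns 4 and 2, so comparing it with the centroid-based program on small random trees is a quick way to catch counting mistakes.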