A tree is a connected graph that doesn’t contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
Input
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Next n - 1 lines describe the edges as “ai bi” (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.
Output
Print a single integer — the number of distinct pairs of the tree’s vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Sample test(s)
input
5 2
1 2
2 3
3 4
2 5
output
4
input
5 3
1 2
2 3
3 4
4 5
output
2
Note
In the first sample the pairs of vertices at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
解题思路:简单树形DP。
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>
#include <queue>
#include <map>
#include <set>
#include <utility>
#include <algorithm>
#include <functional>
using namespace std;
typedef long long ll;
const int maxn = 50010;
ll dp1[maxn][510];
ll dp2[maxn];
int n, k;
ll ans;
struct Edge {
int v, next;
Edge() { }
Edge(int _v, int _next) : v(_v), next(_next) { }
}edges[2*maxn];
int head[maxn], edge_sum;
void init_graph() {
edge_sum = 0;
memset(head, -1, sizeof(head));
}
void add_edge(int u, int v) {
edges[edge_sum].v = v;
edges[edge_sum].next = head[u];
head[u] = edge_sum++;
edges[edge_sum].v = u;
edges[edge_sum].next = head[v];
head[v] = edge_sum++;
}
void dfs1(int u, int fa) {
dp1[u][0] += 1;
for(int i = head[u]; i != -1; i = edges[i].next) {
int v = edges[i].v;
if(v == fa) continue;
dfs1(v, u);
for(int j = 0; j < k; ++j) {
dp1[u][j+1] += dp1[v][j];
}
}
return ;
}
int par[maxn][510];
void dfs2(int u, int fa) {
par[u][0] = u;
dp2[u] = dp1[u][k];
if(fa != -1) {
for(int i = 1; i <= k; ++i) {
par[u][i] = par[fa][i-1];
}
for(int i = 1; i <= k; ++i) {
if(par[u][i] == -1) break;
dp2[u] += dp1[par[u][i]][k-i];
if(k-i-1 >= 0) {
dp2[u] -= dp1[par[u][i-1]][k-i-1];
}
}
}
ans += dp2[u];
for(int i = head[u]; i != -1; i = edges[i].next) {
int v = edges[i].v;
if(v == fa) continue;
dfs2(v, u);
}
return ;
}
int main() {
//freopen("aa.in", "r", stdin);
int u, v;
scanf("%d %d", &n, &k);
init_graph();
for(int i = 1; i < n; ++i) {
scanf("%d %d", &u, &v);
add_edge(u, v);
}
ans = 0;
dfs1(1, -1);
memset(par, -1, sizeof(par));
dfs2(1, -1);
printf("%I64d\n", ans/2);
return 0;
}