A tree is a connected graph that doesn't contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Each of the next n − 1 lines describes an edge as "a_i b_i" (without the quotes), where a_i and b_i (1 ≤ a_i, b_i ≤ n, a_i ≠ b_i) are the vertices connected by the i-th edge. All given edges are distinct.
Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in C++. It is preferred to use the cin, cout streams or the %I64d specifier.
Sample input 1:
5 2
1 2
2 3
3 4
2 5
Sample output 1:
4
Sample input 2:
5 3
1 2
2 3
3 4
4 5
Sample output 2:
2
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<vector>
using namespace std;
// 64-bit accumulator type for the pair count.
typedef long long ll;
// Capacity of vertex-indexed arrays (vertices are numbered 1..n, n <= 50000).
const int maxm = 100005;
// Capacity of depth-indexed arrays (flag/cnt are indexed by depth 0..k, k <= 500).
const int maxk = 505;
// Adjacency lists of the tree.
std::vector<int> v[maxm];
// n: vertex count; m: required distance k; root: best centroid candidate found by getroot;
// sn: size of the component currently being searched for a centroid.
// tot: not referenced by the code shown here.
int n, root, m, sn, tot;
// Number of unordered vertex pairs at distance exactly m.
ll ans;
// son[x]: subtree size of x in the latest getroot traversal; s[x]: size of the largest piece left if x were removed;
// vis[x]: x has already been used as a centroid; flag[d]: vertices at depth d in branches already merged around
// the current centroid; cnt[d]: vertices at depth d in the branch being scanned.
int son[maxm], s[maxm], vis[maxm], flag[maxk], cnt[maxk];
// Search the component containing k (vertices with vis == 0) for its centroid.
// Fills son[x] with the subtree size of every component vertex, taking the start vertex as the
// traversal root. Fills s[x] with the size of the largest piece the component would split into
// if x were removed. Updates `root` whenever a vertex with a strictly smaller s[] is found.
// The caller must set `sn` to the component size and reset `root` to the sentinel 0 beforehand.
void getroot(int k, int pre)
{
    son[k] = 1, s[k] = 0;
    for (int to : v[k])
    {
        if (to == pre || vis[to]) continue;
        getroot(to, k);
        son[k] += son[to];
        s[k] = max(s[k], son[to]);  // piece below k through child `to`
    }
    // The piece "above" k holds every component vertex outside k's own subtree.
    // (The original used sn - s[k], i.e. subtracted the largest child instead of son[k].)
    s[k] = max(s[k], sn - son[k]);
    if (s[root] > s[k]) root = k;
}
// Walk the branch entered at k, where k lies at distance `deep` from the current centroid.
// Vertices already used as centroids and the vertex we came from are skipped.
// For every vertex at depth d <= m:
//   - add flag[m - d] to ans (pairs with vertices at depth m - d in branches merged earlier
//     by query());
//   - record the depth in cnt[d] so query() can merge this branch afterwards.
void getdeep(int k, int pre, int deep)
{
    // A vertex farther than m from the centroid cannot complete a pair of length m.
    if (deep > m) return;
    ans += flag[m - deep];
    cnt[deep]++;
    // Children would be at depth m + 1 and contribute nothing; stop descending.
    if (deep == m) return;
    for (int to : v[k])
    {
        if (to == pre || vis[to]) continue;
        getdeep(to, k, deep + 1);
    }
}
// Count every pair whose connecting path passes through the centroid k.
// Each remaining neighbour branch is scanned in turn. getdeep() pairs that branch's vertices with
// the depth histogram flag[] of earlier branches; k itself is the single vertex at depth 0.
// The branch's own histogram cnt[] is folded into flag[] only after the scan, so that no pair is
// formed inside a single branch.
void query(int k)
{
    memset(flag, 0, sizeof(flag));
    flag[0] = 1;  // the centroid itself, at depth 0
    for (int child : v[k])
    {
        if (!vis[child])
        {
            memset(cnt, 0, sizeof(cnt));
            getdeep(child, 0, 1);
            for (int d = 1; d <= m; ++d)
                flag[d] += cnt[d];
        }
    }
}
// Centroid-decomposition driver.
// Preconditions: k is the centroid that getroot() just selected for its component, and sn still
// holds that component's size (set by the caller right before getroot()).
// Effect: marks k as removed, counts all pairs whose path passes through k, then recurses into
// each remaining sub-component.
void dfs(int k)
{
    // Size of k's component; saved because the recursion below overwrites sn.
    const int total = sn;
    vis[k] = 1, query(k);
    for (int to : v[k])
    {
        if (vis[to]) continue;
        // son[] comes from the getroot() traversal of this component, which may have started
        // above k. A neighbour below k has son[to] < son[k] and that value is exact. The single
        // neighbour above k heads the remaining total - son[k] vertices, so its son[] is stale.
        // Later recursive calls only touch other sub-components, so son[to] and son[k] are
        // still intact here.
        sn = son[to] < son[k] ? son[to] : total - son[k];
        root = 0;
        getroot(to, 0), dfs(root);
    }
}
// Read the tree, run the centroid decomposition from the centroid of the whole tree, and print
// the number of pairs at distance exactly m.
int main()
{
    int x, y;
    if (scanf("%d%d", &n, &m) != 2) return 1;  // malformed input
    for (int i = 1; i < n; i++)
    {
        if (scanf("%d%d", &x, &y) != 2) return 1;  // malformed input
        v[x].push_back(y), v[y].push_back(x);
    }
    // Vertex 0 is never part of the tree. Its huge s[] value makes it a sentinel "worst"
    // candidate, so getroot() replaces root = 0 with the first real vertex it examines.
    s[0] = 100000000, root = 0, sn = n;
    getroot(1, 0), dfs(root);
    printf("%lld\n", ans);
    return 0;
}