D. Distance in Tree
time limit per test
3 seconds
memory limit per test
512 megabytes
input
standard input
output
standard output
A tree is a connected graph that doesn't contain any cycles.
The distance between two vertices of a tree is the length (in edges) of the shortest path between these vertices.
You are given a tree with n vertices and a positive number k. Find the number of distinct pairs of the vertices which have a distance of exactly k between them. Note that pairs (v, u) and (u, v) are considered to be the same pair.
Input
The first line contains two integers n and k (1 ≤ n ≤ 50000, 1 ≤ k ≤ 500) — the number of vertices and the required distance between the vertices.
Next n - 1 lines describe the edges as "ai bi" (without the quotes) (1 ≤ ai, bi ≤ n, ai ≠ bi), where ai and bi are the vertices connected by the i-th edge. All given edges are different.
Output
Print a single integer — the number of distinct pairs of the tree's vertices which have a distance of exactly k between them.
Please do not use the %lld specifier to read or write 64-bit integers in С++. It is preferred to use the cin, cout streams or the %I64dspecifier.
Examples
input
Copy
5 2 1 2 2 3 3 4 2 5
output
Copy
4
input
Copy
5 3 1 2 2 3 3 4 4 5
output
Copy
2
Note
In the first sample the pairs of vertexes at distance 2 from each other are (1, 3), (1, 5), (3, 5) and (2, 4).
一棵树,再给一个k,然后让你求存在多少对a_i、b_i,并且dis(a_i,b_i)= k,(2,3)(3,2)算为一对
一看范围,自己就猜了一个n * k的做法,但是状态的定义有点问题,想了很长时间
f[i][j]表示以i为根节点的子树存在与i节点之间距离为j的节点个数,接着我们可以一个dfs,求出这个f[i][j]
设x为d的父亲节点,从上到下就可以很显然的得到一个转移方程
f[x][j] += f[d][j - 1](d为x的所有节点,最后求和相加)
现在我们得到了任意节点i的子树上的节点和节点i的距离为k的个数,你会发现还少一种答案,就是这种我们只得到了绿色部分的答案,但是在他的兄弟节点的子树下面和他父亲的上面都还会有答案(即使父亲节点可能与2节点已经有答案了,但是我们姑且还是算两次,最后除以2就可以),现在就是考虑黄色部分怎么求
设x为d的父亲节点,从底向上可以得到转移方程
f[d][j] += f[x][j - 1] - f[d][j - 2] (首先可以加上f[x]][j - 1]即距离父亲节点所有距离j - 1的节点个数,然后对于d为子树的部分算重复了,还要减去f[d][j - 2] )完毕,最后计数一下f[i][k] 就可以
做完之后我又去搜了一下题解,发现了另外一种解题方式,其实这种题是一种长链剖分的模板题目 ,有兴趣的可以去看看
#include <bits/stdc++.h>
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef double db;
int xx[4] = {1,-1,0,0};
int yy[4] = {0,0,1,-1};
const double eps = 1e-9;
typedef pair<int,int> P;
const int maxn = 5e4 + 5000;
const ll mod = 1e9 + 7;
inline int sign(db a) {
return a < -eps ? -1 : a > eps;
}
inline int cmp(db a,db b) {
return sign(a - b);
}
ll mul(ll a,ll b,ll c) {
ll res = 1;
while(b) {
if(b & 1) res *= a,res %= c;
a *= a,a %= c,b >>= 1;
}
return res;
}
ll phi(ll x) {
ll res = x;
for(ll i = 2; i * i <= x; i++) {
if(x % i == 0) res = res / i * (i - 1);
while(x % i == 0) x /= i;
}
if(x > 1) res = res / x * (x - 1);
return res;
}
int fa[maxn];
int Find(int x) {
if(x != fa[x]) return fa[x] = Find(fa[x]);
return fa[x];
}
ll c,n,k;
ll h,s,m;
vector<int>v[maxn];
ll f[maxn][510];
int deep[maxn];
void dfs1(int x,int fa) {
for(auto d:v[x]) {
if(d == fa) continue;
dfs1(d,x);
for(int i = 1; i <= k; i++)
f[x][i] += f[d][i - 1];
}
}
void dfs2(int x,int fa) {
for(auto d:v[x]) {
if(d == fa) continue;
for(int i = k;i >= 1;i--)
f[d][i] += f[x][i - 1] - f[d][i - 2];
dfs2(d,x);
}
}
int main() {
ios::sync_with_stdio(false);
while(cin >> n >> k) {
for(int i = 1; i < n; i++) {
int u,vv;
cin >> u >> vv;
v[vv].push_back(u);
v[u].push_back(vv);
}
for(int i = 1; i <= n; i++)
f[i][0] = 1;
dfs1(1,0);
dfs2(1,0);
ll ans = 0;
for(int i = 1; i <= n; i++) {
ans += f[i][k];
}
ans /= 2;
cout << ans << endl;
}
cerr << "time: " << (long long)clock() * 1000 / CLOCKS_PER_SEC << " ms" << endl;
return 0;
}