Paths on the Tree
Time Limit: 6 Seconds; Memory Limit: 131072 KB
Edward has a tree with n vertices conveniently labeled with 1,2,…,n.
Edward is interested in pairs of paths on the tree that share at most k common vertices. Specifically, he wants to know the number of such ordered pairs of paths.
Note that the path from vertex a to vertex b is the same as the path from vertex b to vertex a, and a single vertex (a = b) also counts as a path. An ordered pair means that (A, B) is different from (B, A) unless A equals B.
Input
There are multiple test cases. The first line of input contains an integer T indicating the number of test cases. For each test case:
The first line contains two integers n, k (1 ≤ n, k ≤ 88888). Each of the following n - 1 lines contains two integers ai, bi, denoting an edge between vertices ai and bi (1 ≤ ai, bi ≤ n).
The sum of values n for all the test cases does not exceed 888888.
Output
For each case, output a single integer denoting the number of ordered pairs of paths sharing no more than k vertices.
Sample Input
1
4 2
1 2
2 3
3 4
Sample Output
93
Hint
The number of path pairs that share no common vertex is 30.
The number of path pairs that share 1 common vertex is 44.
The number of path pairs that share 2 common vertices is 19.
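| Path A  | Paths that share exactly 2 vertices with A | Total |
|---------|--------------------------------------------|-------|
| 1-2-3-4 | 1-2, 2-3, 3-4                              | 3     |
| 1-2-3   | 1-2, 2-3, 2-3-4                            | 3     |
| 2-3-4   | 1-2-3, 2-3, 3-4                            | 3     |
| 1-2     | 1-2, 1-2-3, 1-2-3-4                        | 3     |
| 2-3     | 1-2-3, 1-2-3-4, 2-3, 2-3-4                 | 4     |
| 3-4     | 1-2-3-4, 2-3-4, 3-4                        | 3     |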
Problem summary: the statement is short, so it is not restated here.
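Approach: counting the pairs that share at most k vertices directly seems awkward. Instead, count the pairs that share more than k vertices and subtract them from the total number of ordered pairs, (n(n-1)/2 + n)^2.

The shared part of two paths is itself a path. So we can enumerate that common path and count how the two paths can branch off at its two ends, i.e. count how many path pairs have exactly the enumerated path as their common part. Suppose the common part runs from x to y. Let T_x be the number of vertices on x's end: x itself plus everything reachable from x without moving along the common path. Let S_x be the sum of the squared sizes of the branches hanging off x on that end. Then the number of such ordered pairs is (T_x^2 - S_x)(T_y^2 - S_y). Each path may stop at x or continue into a branch, but the two paths must not enter the same branch.

Enumerating the paths directly is O(N^2), so we use centroid decomposition (tree divide and conquer). We treat separately the common parts that pass through the centroid and those that do not, which brings the complexity down to O(n log n). The implementation has quite a few fiddly details; see the code for them.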
Code:
#include <cstdio>
#include <iostream>
#include <cmath>
#include <algorithm>
#include <cstring>
#include <queue>
#include <vector>
#define rep(i,a,b) for(int i=(a);i<(b);++i)
#define rrep(i,b,a) for(int i=(b);i>=(a);--i)
#define clr(a,x) memset(a,(x),sizeof(a))
// Every path-pair count in this program is accumulated modulo 2^64.
// The final answer can reach about (88888*88887/2 + 88888)^2 ~ 1.56e19, which
// is above LLONG_MAX, and the subtracted "more than K shared vertices" count
// can exceed LLONG_MAX as well. Signed overflow is undefined behaviour, while
// unsigned wrap-around is well defined. Because the true results are below
// 2^64, doing all of this arithmetic in an unsigned 64-bit type yields exact
// answers. The result is printed with %llu.
typedef unsigned long long ll;
#define lson l, m, rt<<1
#define rson m+1,r,rt<<1|1
#define mp make_pair
#define ld long double
// Capacity shared by per-vertex and per-directed-edge arrays: a tree with
// n <= 88888 vertices stores 2 * (n - 1) directed edges.
const int maxn = 2 * 88889;

int N;  // number of vertices in the current test case
int K;  // largest number of shared vertices that is still allowed

// visit[v] becomes true once v has been used as a centroid (see dfs).
bool visit[maxn];

// Forward-star adjacency storage. input() always adds an edge and its reverse
// back to back, so for any edge index e the reverse edge is e ^ 1.
struct Node
{
    int v;     // head vertex of this directed edge
    int next;  // next edge leaving the same tail vertex, or -1
} edges[maxn];

int first[maxn];  // first[u]: most recently added edge leaving u, or -1
int ptr;          // number of directed edges stored so far
void add(int u,int v)
{
edges[ptr].v = v;
edges[ptr].next = first[u];
first[u] = ptr++;
}
// Reads the N - 1 tree edges of the current test case and stores each one in
// both directions (edges 2t and 2t+1 are reverses of each other).
void input()
{
    // Only vertex labels 1..N occur in this test case, so resetting first[0..N]
    // is enough. Clearing the whole maxn-sized array costs O(maxn) per test
    // case, which dominates when there are many tiny test cases.
    std::fill(first, first + N + 1, -1);
    ptr = 0;
    rep(i,0,N-1) {
        int u, v;
        scanf("%d%d",&u,&v);
        add(u,v);
        add(v,u);
    }
}
// findWeight scratch: size of the largest piece left if the vertex is removed
// from the current component.
int maxTreeSize[maxn];
// BFS depth written by getNode, counted from the depth the caller assigned
// to the BFS start vertex.
int depth[maxn];
// Index of the directed edge from the vertex to its BFS parent, or -1.
int parent[maxn];
// Per-vertex size value; its exact meaning depends on which routine
// (pre_init, getNode, findWeight) wrote it last.
ll treeSize[maxn];
// Per directed edge, filled by pre_init.
ll dpSize[maxn];
ll dpSquareSize[maxn];
// Vertices found by the latest getNode call, in BFS order.
int id[maxn];
// Callers store getNode's return value (the number of entries in id) here.
int idSize;
// Per-vertex sum of squared branch sizes, the companion of treeSize.
ll squareOfSize[maxn];
// getNode's "already queued" marks; it clears them again before returning.
bool inId[maxn];
// BFS over the component that contains s: the vertices reachable from s
// without entering a vertex whose visit[] flag is set.
// Precondition: the caller has already set depth[s]. Every other reached vertex
// gets depth[its BFS parent] + 1, so depths are offsets from depth[s].
// Side effects, for each vertex v reached from u through edge i (u -> v):
//   id[0..r)        : the component's vertices in BFS order (id[0] == s)
//   parent[v]       : i ^ 1, the reverse edge v -> u (parent[s] is set to -1)
//   treeSize[v]     : dpSize[i ^ 1], i.e. the number of vertices on v's side of
//                     edge {u, v} in the WHOLE tree as computed by pre_init
//                     (visit[] is ignored here)
//   squareOfSize[v] : dpSquareSize[i ^ 1]
// For s itself, treeSize[s] = 1 + the sum of dpSize[i ^ 1] over every edge i
// out of s (visited neighbours included), and squareOfSize[s] is the sum of
// their squares.
// When pre_init() calls this, the dp arrays are still stale (zero or left over
// from an earlier test case). pre_init overwrites treeSize/squareOfSize
// afterwards.
// Returns r, the number of vertices in the component.
int getNode(int s)
{
int l = 0, r = 0;
id[r++] = s;
int u;
inId[s] = true;
parent[s] = -1;
while (l < r) {
u = id[l++];
for(int i = first[u]; ~i; i = edges[i].next) {
int v = edges[i].v;
// Skip vertices already queued in this BFS and vertices marked in visit[].
if (inId[v] || visit[v]) continue;
depth[v] = depth[u] + 1;
parent[v] = i ^ 1;
id[r++] = v;
inId[v] = true;
treeSize[v] = dpSize[i ^ 1];
squareOfSize[v] = dpSquareSize[i ^ 1];
}
}
treeSize[s] = 1;
squareOfSize[s] = 0;
for(int i = first[s]; ~i; i = edges[i].next) {
// parent[s] was set to -1 above, so this never skips: every edge out of s
// contributes, including edges to visited neighbours.
if (i == parent[s]) continue;
squareOfSize[s] += dpSize[i ^ 1] * dpSize[i ^ 1];
treeSize[s] += dpSize[i ^ 1];
}
// Reset the scratch marks so the next BFS starts clean.
rep(i,0,r) inId[id[i]] = false;
return r;
}
void pre_init()
{
clr(parent,-1);
idSize = getNode(1);
rrep(i,idSize-1,0) {
int x = id[i];
treeSize[x] = 1; squareOfSize[x] = 0;
for(int j = first[x] ; ~j; j = edges[j].next) {
int y = edges[j].v;
if (parent[x] == j) continue;
treeSize[x] += treeSize[y];
squareOfSize[x] += treeSize[y] * treeSize[y];
}
}
rrep(i,idSize-1,0) {
int x = id[i];
if (parent[x] != -1) {
dpSize[parent[x]] = treeSize[x];
dpSquareSize[parent[x]] = squareOfSize[x];
}
for(int j = first[x]; ~j; j = edges[j].next) {
if (j == parent[x]) continue;
int y = edges[j].v;
dpSize[j] = N - treeSize[y];
dpSquareSize[j] = squareOfSize[x] + (N - treeSize[x]) * (N - treeSize[x]) - treeSize[y] * treeSize[y];
}
}
}
int findWeight(int s)
{
idSize = getNode(s);
rrep(i,idSize-1,0) {
int x = id[i];
maxTreeSize[x] = 0;
treeSize[x] = 1;
for(int j = first[x]; ~j; j = edges[j].next) {
int y = edges[j].v;
if (parent[x] == j || visit[y]) continue;
treeSize[x] += treeSize[y];
maxTreeSize[x] = std::max((ll)maxTreeSize[x], treeSize[y]);
}
maxTreeSize[x] = std::max((ll)maxTreeSize[x], idSize - treeSize[x]);
}
int result = id[0];
rep(i,1,idSize) {
int x = id[i];
if (maxTreeSize[x] < maxTreeSize[result]) result = x;
}
return result;
}
// Counts ordered path pairs whose common part is the path between two
// non-root vertices x, y of id[0..idSize). That path is taken to pass through
// the BFS root, so it has depth[x] + depth[y] + 1 vertices, and only pairs with
// more than K such vertices are counted.
// Each pair {x, y} contributes (T_x^2 - S_x) * (T_y^2 - S_y), where
// T = treeSize and S = squareOfSize as left by getNode. This is the number of
// ordered ways the two paths can end beyond x and beyond y without entering the
// same branch.
// Relies on id[] being in BFS order, i.e. depth is non-decreasing along id[].
ll CountCommon()
{
    // Running sums of T^2 and S over the candidate partners id[lo .. j-1].
    ll treeSizeSum = 0, squareSum = 0;
    for (int t = 0; t < idSize; ++t) {
        int x = id[t];
        if (depth[x] == 0) continue;  // the root is never an endpoint here
        treeSizeSum += treeSize[x] * treeSize[x];
        squareSum += squareOfSize[x];
    }
    ll result = 0;
    int lo = 0;
    if (depth[id[0]] == 0) ++lo;  // skip the root if it heads the list
    for (int j = idSize - 1; j >= 0; --j) {
        int y = id[j];
        if (depth[y] <= 0) continue;
        // y is only paired with earlier vertices, so drop it from the sums.
        treeSizeSum -= treeSize[y] * treeSize[y];
        squareSum -= squareOfSize[y];
        // Depths grow along id[], and depth[y] never grows as j decreases.
        // A partner too shallow for this y is therefore too shallow for every
        // later y: advance lo monotonically and drop such partners from the sums.
        while (lo < j && depth[id[lo]] + depth[y] + 1 <= K) {
            treeSizeSum -= treeSize[id[lo]] * treeSize[id[lo]];
            squareSum -= squareOfSize[id[lo]];
            ++lo;
        }
        // If lo < j, the loop stopped because id[lo] is deep enough, so every
        // partner in id[lo .. j-1] qualifies. Once lo reaches j, the sums may be
        // decremented twice later on, but they are never read again.
        if (lo < j) {
            ll k = treeSize[y] * treeSize[y] - squareOfSize[y];
            result += k * (treeSizeSum - squareSum);
        }
    }
    return result;
}
// Counts ordered path pairs whose common part runs from the centroid rt to a
// vertex x of id[0..idSize), for every x whose path to rt has more than K
// vertices (depth[x] + 1 > K, with depths measured from rt). The caller fills
// id[] with the subtree hanging off rt's neighbour u.
// At the rt end, the two paths may stop at rt or enter any branch of rt
// except u's, and not both the same branch:
//   k = (T_rt - T_u)^2 - (S_rt - T_u^2).
// At the x end, the factor is T_x^2 - S_x.
ll CountSpecial(int u,int rt)
{
    // With treeSize[rt] as set by getNode(rt), this is rt's side of edge {rt, u}.
    ll outside = treeSize[rt] - treeSize[u];
    ll k = outside * outside - squareOfSize[rt] + treeSize[u] * treeSize[u];
    ll result = 0;
    rep(i,0,idSize) {
        // Named x so it no longer shadows the parameter u.
        int x = id[i];
        if (depth[x] + 1 > K)
            result += k * (treeSize[x] * treeSize[x] - squareOfSize[x]);
    }
    return result;
}
// One centroid-decomposition step for the component of unvisited vertices
// that contains u. Returns the number of ordered path pairs whose shared part
// is a path with more than K vertices lying inside this component. solve()
// subtracts the total from the number of all ordered pairs.
// Branch sizes come from the whole-tree DP in pre_init, so the ends of the two
// paths may leave the component.
ll dfs(int u)
{
depth[u] = 0; parent[u] = -1;
// Replace u by its component's centroid; idSize becomes the component size.
u = findWeight(u);
// A component with at most K vertices cannot contain a path with more than K vertices.
if (idSize <= K) return 0;
depth[u] = 0; parent[u] = -1;
// Redo the BFS from the centroid so depths and branch sizes are relative to it.
idSize = getNode(u);
// Pairs of non-centroid endpoints, each treated as if its path went through u.
ll ans = CountCommon();
visit[u] = true;
for(int i = first[u]; ~i; i = edges[i].next) {
int v = edges[i].v;
if (visit[v]) continue;
// BFS only v's subtree (u is now visited). depth[v] is still 1 from the
// centroid BFS above, so depths stay measured from u.
idSize = getNode(v);
// getNode(v) recomputed v's values as a BFS root; restore the values for
// edge v -> u that the centroid BFS had assigned.
treeSize[v] = dpSize[i ^ 1];
squareOfSize[v] = dpSquareSize[i ^ 1];
// Remove pairs inside the same subtree that CountCommon above routed through u.
ans -= CountCommon();
// Add pairs whose shared part has the centroid u as one endpoint.
ans += CountSpecial(v, u);
}
// Recurse into every remaining sub-component.
for(int i = first[u]; ~i ; i = edges[i].next) {
int v = edges[i].v;
if (visit[v]) continue;
ans += dfs(v);
}
return ans;
}
void solve()
{
clr(visit,0);
pre_init();
ll ans = (ll)N * (N - 1) / 2 + N;
ans *= ans;
ans -= dfs(1);
printf("%llu\n",(unsigned long long)ans);
/*ll std_ans = 0;
rep(i,1,N+1) rep(j,i,N+1) {
std_ans += (j - i - 1);
std_ans += (ll) i * (i - 1) / 2 + i;
std_ans += (ll) (N - j + 1) * (N - j) / 2 + (N - j + 1);
}
rep(i,1,N+1) std_ans += (ll) (i - 1) * (N - i);
printf("%llu\n",std_ans);*/
}
// Test-data generator (its call in main is commented out). It redirects stdout
// to in.txt and writes a single case: the path 1-2-...-88888 with k = n.
void Getinput()
{
    freopen("in.txt","w",stdout);
    const int cases = 1;
    printf("%d\n", cases);
    for (int c = 0; c < cases; ++c) {
        const int n = 88888;
        const int k = n;
        printf("%d %d\n", n, k);
        for (int v = 1; v < n; ++v)
            printf("%d %d\n", v, v + 1);
    }
}
// Reads T test cases and solves each one.
int main()
{
    //Getinput(); return 0;
#ifdef ACM
    freopen("in.txt","r",stdin);
    //freopen("data.in","r",stdin);
    //freopen("data.out","w",stdout);
#endif // ACM
    // Use C stdio for all input, and give T a defined value even if the read fails.
    int T = 0;
    if (scanf("%d", &T) != 1) return 0;
    rep(cas,1,T+1) {
        // Stop cleanly on truncated input instead of reusing stale N/K.
        if (scanf("%d%d",&N,&K) != 2) break;
        input();
        solve();
    }
    return 0;
}