hdu 5102 The K-th Distance
The K-th Distance
Time Limit: 8000/4000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others)
Total Submission(s): 27 Accepted Submission(s): 5
Problem Description
Given a tree with n nodes in total, define the distance between two nodes u and v as the number of edges on the unique path between them. This gives n(n−1)/2 distances, one for each pair of nodes; sort them in ascending order. The task is to output the sum of the first K of them.
Input
There are several test cases; the first line contains the number of test cases T (there are at most twenty cases).
For each case, the first line contains two integers n and K (2 ≤ n ≤ 100000, 0 ≤ K ≤ min(n(n−1)/2, 10^6)). Each of the following n−1 lines contains two integers u and v, indicating that there is an edge between nodes u and v.
Output
For each case output the answer.
Sample Input
2
3 3
1 2
2 3
5 7
1 2
1 3
2 4
2 5
Sample Output
4
10
Source
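Idea: K ≤ 10^6. In a chain with n = 10^5 nodes there are n−1 pairs at distance 1, n−2 pairs at distance 2, and so on, so the first K distances never exceed about 20 there.

More generally, every node has at least min(d, n−1) other nodes within distance d. Hence at least n·min(d, n−1)/2 pairs are at distance ≤ d, and it suffices to count only the distances up to min(n−1, ⌈2K/n⌉). The heuristic bound K/n + 70 is not always enough: a chain with n = 1500 and K = 10^6 needs distances up to 1002.

Then count the pairs at each such distance directly with tree divide-and-conquer (centroid decomposition).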
// HDU 5102 "The K-th Distance": sum of the K smallest pairwise distances in a tree.
//
// Every node has at least min(d, n-1) other nodes within distance d, so at
// least n*min(d, n-1)/2 unordered pairs lie at distance <= d.  Therefore the K
// smallest distances are all <= maxDist = min(n-1, ceil(2K/n)), and it is
// enough to count, for every d <= maxDist, how many pairs are at distance
// exactly d.  Those counts come from centroid decomposition, where each
// centroid's component is explored only up to depth maxDist.
//
// (The earlier heuristic bound K/n + 70 is unsound: a chain with n = 1500 and
// K = 10^6 needs distances up to 1002, but that bound gives 736.)
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <vector>

const int MAXN = 100010;

// Adjacency list stored as singly linked edge lists (head / nxt / to).
int head[MAXN], to[MAXN * 2], nxt[MAXN * 2];
int top;

// removed[v] is true once v has been used as a centroid.
bool removed[MAXN];

// Scratch arrays for the iterative centroid search.
int bfsOrder[MAXN], bfsParent[MAXN], subtreeSize[MAXN];

// Largest distance that can appear among the K smallest ones.
int maxDist;

// num[d]: nodes at depth d in the subtrees of the current centroid merged so far.
// numTop: largest d with num[d] != 0 (0 when num is empty).
// usedDepths: distinct depths set in num, so it can be cleared cheaply.
int num[MAXN], numTop;
std::vector<int> usedDepths;

// ans[d]: number of unordered pairs at distance exactly d (for d <= maxDist).
// Must be 64-bit: a star with 70000 leaves already has ~2.4e9 pairs at distance 2.
long long ans[MAXN];

// Depths of the nodes reached during the current subtree walk.
std::vector<int> depths;

// Appends the directed edge u -> v to u's edge list.
void addEdge(int u, int v) {
    nxt[top] = head[u];
    to[top] = v;
    head[u] = top++;
}

// Returns a centroid of the component of non-removed nodes that contains
// `start`.  Implemented with an explicit BFS (no recursion), so a long chain
// cannot overflow the call stack.
int getRoot(int start) {
    int componentSize = 0;
    bfsOrder[componentSize++] = start;
    bfsParent[start] = -1;
    for (int idx = 0; idx < componentSize; ++idx) {
        const int u = bfsOrder[idx];
        for (int e = head[u]; e != -1; e = nxt[e]) {
            const int v = to[e];
            if (v == bfsParent[u] || removed[v]) continue;
            bfsParent[v] = u;
            bfsOrder[componentSize++] = v;
        }
    }

    // Reverse BFS order visits every child before its parent, so subtree
    // sizes can be accumulated bottom-up.
    int best = start;
    int bestPart = componentSize;
    for (int idx = componentSize - 1; idx >= 0; --idx) {
        const int u = bfsOrder[idx];
        subtreeSize[u] = 1;
        int heaviestChild = 0;
        for (int e = head[u]; e != -1; e = nxt[e]) {
            const int v = to[e];
            if (v == bfsParent[u] || removed[v]) continue;
            subtreeSize[u] += subtreeSize[v];
            heaviestChild = std::max(heaviestChild, subtreeSize[v]);
        }
        // Largest piece left if u were removed: a child subtree or the part above u.
        const int largestPart = std::max(heaviestChild, componentSize - subtreeSize[u]);
        if (largestPart < bestPart) {
            bestPart = largestPart;
            best = u;
        }
    }
    return best;
}

// A node at depth `depth` in the subtree being walked forms a pair at
// distance i + depth with each of the num[i] nodes at depth i in the
// centroid's earlier subtrees.  Only sums <= maxDist are recorded.
void addPairsWithEarlierSubtrees(int depth) {
    const int limit = std::min(numTop, maxDist - depth);
    for (int i = 1; i <= limit; ++i) ans[i + depth] += num[i];
}

// Walks the subtree below the current centroid, starting at u (depth >= 1,
// depth <= maxDist).  For every node it records its depth and counts its
// pairs with the centroid and with earlier subtrees.  Recursion depth is
// bounded by maxDist.
void collectDepths(int u, int parent, int depth) {
    ++ans[depth];                        // pair (centroid, u)
    depths.push_back(depth);
    addPairsWithEarlierSubtrees(depth);
    if (depth == maxDist) return;        // deeper nodes are too far to matter
    for (int e = head[u]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (v == parent || removed[v]) continue;
        collectDepths(v, u, depth + 1);
    }
}

// Adds to ans every pair whose path passes through `root`; root must already
// be marked removed.
void countThroughCentroid(int root) {
    numTop = 0;
    for (int e = head[root]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (removed[v]) continue;
        depths.clear();
        collectDepths(v, root, 1);
        // Merge only after the walk, so pairs inside one subtree are not
        // counted at this centroid (they are counted at a deeper one).
        for (const int d : depths) {
            if (num[d] == 0) usedDepths.push_back(d);
            ++num[d];
            numTop = std::max(numTop, d);
        }
    }
    for (const int d : usedDepths) num[d] = 0;
    usedDepths.clear();
}

// Centroid decomposition of the component containing u.  Recursion depth is
// O(log n) because each component at least halves in size.
void decompose(int u) {
    const int root = getRoot(u);
    removed[root] = true;
    countThroughCentroid(root);
    for (int e = head[root]; e != -1; e = nxt[e]) {
        const int v = to[e];
        if (!removed[v]) decompose(v);
    }
}

// Resets all per-test-case state.
void init() {
    top = 0;
    std::memset(head, -1, sizeof(head));
    std::memset(removed, 0, sizeof(removed));
    std::memset(num, 0, sizeof(num));
    std::memset(ans, 0, sizeof(ans));
}

int main() {
    int T;
    if (std::scanf("%d", &T) != 1) return 0;
    while (T--) {
        int n, k;
        if (std::scanf("%d%d", &n, &k) != 2) break;
        // ceil(2K/n), capped at n-1 (no distance in the tree exceeds n-1).
        maxDist = static_cast<int>(std::min<long long>(n - 1, (2LL * k + n - 1) / n));
        init();
        for (int i = 1; i < n; ++i) {
            int u, v;
            if (std::scanf("%d%d", &u, &v) != 2) return 0;
            addEdge(u, v);
            addEdge(v, u);
        }
        if (k > 0) decompose(1);

        // Take distances in increasing order until K of them are used.
        long long sum = 0;
        long long remaining = k;
        for (int d = 1; d <= maxDist && remaining > 0; ++d) {
            const long long take = std::min(remaining, ans[d]);
            sum += take * d;
            remaining -= take;
        }
        std::printf("%lld\n", sum);   // %lld is portable; %I64d is MSVC-only
    }
    return 0;
}