传送门:【ZJU】3863 Paths on the Tree
题意:给一棵树,问树上有多少个路径对有不超过K个公共节点的,路径a->b和b->a等价,路径对(A,B)和(B,A)只有当A和B是同一条路径时相同。
分析:反过来考虑,先统计有至少K+1个(即超过K个)公共节点的路径对数,再用路径对总数减去它。我们考虑重叠的路径部分,这个可以用树分治来搞,然后路径对的两端延伸出去的部分不重叠,我们要预处理出这个部分。最后就是当前枚举的子树和之前子树乘一乘。
trick:路径对总数是会爆long long的,要用unsigned long long。
后记:太懒了。。写不动题解啊,粗略的写了下思路,细节可以自己想想,还不理解可以QQ找我。。
代码如下:
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <map>
#include <string>
#include <vector>
using namespace std ;
// Shorthand names for the 64-bit integer types used below.
typedef long long LL ;
typedef unsigned long long ULL ;
// Loop helpers. Each declares its own int index `i`:
//   rep: i runs over a, a+1, ..., b-1   (half-open range)
//   For: i runs over a, a+1, ..., b     (closed range)
//   rev: i runs over a, a-1, ..., b     (closed range, descending)
#define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
// memset over `sizeof a` bytes: `a` must be a real array (not a pointer) and
// x is applied per byte, so only values such as 0 or -1 fill ints meaningfully.
#define clr( a , x ) memset ( a , x , sizeof a )
// Capacity of every per-vertex array.
const int MAXN = 100005 ;
// Capacity of the edge pool E (addedge is called twice per input edge).
const int MAXE = 400005 ;
// One directed half of an undirected tree edge, kept in a per-vertex linked
// list inside the edge pool.
//   v : the vertex this half-edge points to
//   c : number of vertices of the whole tree lying on v's side of the edge;
//       filled in by pre_dfs after the tree has been read
//   n : pool index of the next half-edge leaving the same vertex, -1 ends the list
struct Edge {
    int v , c , n ;
    // Every member is given a definite value so that copying an Edge never
    // reads an indeterminate int.
    Edge () : v ( 0 ) , c ( 0 ) , n ( 0 ) {}
    // c starts at 0 and is overwritten later by pre_dfs.
    Edge ( int v , int n ) : v ( v ) , c ( 0 ) , n ( n ) {}
} ;
// Record produced by calc for one visited vertex.
//   d : BFS depth of the vertex (the centroid's direct neighbour has depth 1)
//   x : weight of the vertex, computed in calc as x*x - y from its edge counts
struct Node {
    int d ;
    ULL x ;
    // Zero-initialise so a default-constructed Node never holds indeterminate values.
    Node () : d ( 0 ) , x ( 0 ) {}
    Node ( int d , ULL x ) : d ( d ) , x ( x ) {}
} ;
// Edge pool; half-edges are appended by addedge.
Edge E[MAXE] ;
// H[u]: pool index of the first half-edge leaving u (-1 when u has none, see clear).
// cntE: number of pool slots in use.
int H[MAXN] , cntE ;
// vis[v] == Time marks v as already chosen as a decomposition root in the
// current test case. clear() bumps Time, so vis never has to be wiped.
int vis[MAXN] , Time ;
// Subtree sizes; written by pre_dfs (whole tree) and by get_root (current component).
int siz[MAXN] ;
// BFS parent of each vertex in the most recent traversal.
int pre[MAXN] ;
// BFS depth of each vertex in the most recent traversal.
int dep[MAXN] ;
// Set by get_root: max_dis bounds the depth indices used for c/cc in dfs and
// calc, tree_size is the number of vertices in the current component.
int max_dis , tree_size ;
// Array-backed BFS queue shared by get_root and calc.
int Q[MAXN] , head , tail ;
// (depth, weight) records of the subtree currently handled by calc; top is their count.
Node S[MAXN] ;
int top ;
// c[d]: summed weight of vertices at depth d over the subtrees of the current
// root that calc has already processed. cc[d]: suffix sum c[d] + c[d+1] + ... .
ULL c[MAXN] , cc[MAXN] ;
// n: number of vertices. k: the input K plus one (solve increments it).
int n , k ;
// ans: running count that solve subtracts from the total number of path pairs.
// tmpc0: n*n minus the sum of squared c over the current root's edges (set in dfs).
ULL ans , tmpc0 ;
// Resets the per-test-case state: empties every adjacency list, rewinds the
// edge pool, zeroes the answer and advances the visit stamp so that marks
// left in vis by the previous test case no longer match.
void clear () {
    memset ( H , -1 , sizeof H ) ;
    cntE = 0 ;
    ans = 0 ;
    Time += 1 ;
}
// Appends the directed half-edge u -> v to the pool and makes it the new
// front of u's adjacency list.
void addedge ( int u , int v ) {
    const int slot = cntE ++ ;
    E[slot] = Edge ( v , H[u] ) ;
    H[u] = slot ;
}
// Breadth-first traversal from s over vertices that have not been removed
// (vis[v] != Time). Fills Q[0..tail) in visiting order and sets pre/dep for
// every reached vertex; dep[s] is 0 and pre[s] is the sentinel 0 (vertex ids
// are assumed to start at 1, as solve roots the tree at vertex 1 with parent 0).
// On return head == tail == number of reached vertices.
static void bfs_component ( int s ) {
    head = tail = 0 ;
    Q[tail ++] = s ;
    dep[s] = 0 ;
    pre[s] = 0 ;
    while ( head != tail ) {
        int u = Q[head ++] ;
        for ( int i = H[u] ; ~i ; i = E[i].n ) {
            int v = E[i].v ;
            if ( v == pre[u] || vis[v] == Time ) continue ;
            pre[v] = u ;
            dep[v] = dep[u] + 1 ;
            Q[tail ++] = v ;
        }
    }
}

// Finds the centroid of the not-yet-removed component containing s.
//
// Returns the vertex whose removal leaves the smallest largest part (ties go
// to the vertex met first when walking the BFS order from s backwards).
// Side effects:
//   tree_size - number of vertices in the component
//   max_dis   - largest BFS depth measured FROM THE RETURNED ROOT. calc
//               indexes c/cc by depth from that root, so the bound has to be
//               taken there; the depth from s can be smaller and would drop
//               the deepest vertices out of the suffix sums.
//   siz       - subtree sizes of the component rooted at s
//   Q, head, tail, pre, dep are scratch and end up describing the BFS from the root.
int get_root ( int s ) {
    bfs_component ( s ) ;
    tree_size = tail ;
    int root = s , root_siz = MAXN ;
    // Children come after their parent in BFS order, so a backwards sweep
    // sees every child's size before the parent needs it.
    for ( int at = tail - 1 ; at >= 0 ; -- at ) {
        int u = Q[at] ;
        int cnt = 0 ;
        siz[u] = 1 ;
        for ( int i = H[u] ; ~i ; i = E[i].n ) {
            int v = E[i].v ;
            if ( v == pre[u] || vis[v] == Time ) continue ;
            siz[u] += siz[v] ;
            if ( siz[v] > cnt ) cnt = siz[v] ;
        }
        // The part on the parent's side counts as one more piece.
        cnt = max ( cnt , tree_size - siz[u] ) ;
        if ( cnt < root_siz ) {
            root_siz = cnt ;
            root = u ;
        }
    }
    // The last vertex dequeued by a BFS is one of the deepest.
    bfs_component ( root ) ;
    max_dis = dep[Q[tail - 1]] ;
    return root ;
}
void calc ( int s , int s_size , int root ) {
head = tail = 0 ;
Q[tail ++] = s ;
dep[s] = 1 ;
pre[s] = root ;
top = 0 ;
ULL c0 = tmpc0 - ( ULL ) 2 * s_size * ( n - s_size ) ;
while ( head != tail ) {
int u = Q[head ++] ;
ULL x = 1 , y = 0 ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v , c = E[i].c ;
if ( v == pre[u] ) continue ;
x += E[i].c ;
y += ( ULL ) E[i].c * E[i].c ;
}
//printf ( "%I64u %I64u %d %d\n" , x , y , dep[u] , u ) ;
S[top ++] = Node ( dep[u] , x * x - y ) ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( v == pre[u] || vis[v] == Time ) continue ;
pre[v] = u ;
dep[v] = dep[u] + 1 ;
Q[tail ++] = v ;
}
}
rep ( i , 0 , top ) {
int idx = max ( 0 , k - S[i].d - 1 ) ;
if ( idx > max_dis ) continue ;
if ( !idx ) ans += c0 * S[i].x ;
ans += cc[idx] * S[i].x ;
}
rep ( i , 0 , top ) c[S[i].d] += S[i].x ;
cc[max_dis] = c[max_dis] ;
rev ( i , max_dis , 1 ) cc[i - 1] = cc[i] + c[i - 1] ;
}
void dfs ( int u ) {
int root = get_root ( u ) ;
if ( tree_size < k ) return ;//no satisfied path
vis[root] = Time ;
memset ( c , 0 , sizeof ( c[0] ) * ( max_dis + 1 ) ) ;
memset ( cc , 0 , sizeof ( cc[0] ) * ( max_dis + 1 ) ) ;
//For ( i , 0 , max_dis ) c[i] = cc[i] = 0 ;
ULL y = 0 ;
for ( int i = H[root] ; ~i ; i = E[i].n ) y += ( ULL ) E[i].c * E[i].c ;
tmpc0 = ( ULL ) n * n - y ;
for ( int i = H[root] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if( vis[v] == Time ) continue ;
calc ( v , E[i].c , root ) ;
}
for ( int i = H[root] ; ~i ; i = E[i].n ) if ( vis[E[i].v] != Time ) dfs ( E[i].v ) ;
}
void pre_dfs ( int u , int f ) {
siz[u] = 1 ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( v == f ) continue ;
pre_dfs ( v , u ) ;
E[i].c = siz[v] ;
siz[u] += siz[v] ;
}
for ( int i = H[u] ; ~i ; i = E[i].n ) if ( E[i].v == f ) {
E[i].c = n - siz[u] ;
break ;
}
}
void solve () {
int u , v ;
clear () ;
scanf ( "%d%d" , &n , &k ) ;
++ k ;
rep ( i , 1 , n ) {
scanf ( "%d%d" , &u , &v ) ;
addedge ( u , v ) ;
addedge ( v , u ) ;
}
ULL x = ( ULL ) n * ( n + 1 ) / 2 ;
pre_dfs ( 1 , 0 ) ;
dfs ( 1 ) ;
printf ( "%llu\n" , x * x - ans ) ;
}
// Entry point: reads the number of test cases and runs solve for each.
int main () {
    int T = 0 ;
    Time = 0 ;
    clr ( vis , 0 ) ;
    // Without a readable test count there is nothing to do.
    if ( scanf ( "%d" , &T ) != 1 ) return 0 ;
    For ( i , 1 , T ) solve () ;
    return 0 ;
}