题目大意
给定一棵 $n$ 个结点的树(根节点为 1),每条边有边权,每个点 $i$ 还给定一个区间 $[L_i, R_i]$。
求每个点 $i$ 的子树中,满足边数在 $[L_i, R_i]$ 范围内、且经过点 $i$ 的最长简单路径的长度(不存在这样的路径时答案为 $-1$)。
Data Constraint
n≤1000000
题解
一个很好的思想:树上启发式合并([Tutorial] Sack (dsu on tree))。类似于树链剖分的思想,可以解决绝大多数的无修改的子树查询问题。
回到本题,定义子树内最长链最长(即子树高度最大)的儿子为重儿子,然后套用上述思想。再用线段树维护每个深度下结点到根的距离的最大值。每次查询时,枚举当前要合并的轻子树中的一个深度,再在线段树的对应深度区间中查询最大值即可。
时间复杂度: O(nlogn)
SRC
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<iostream>
#include<algorithm>
using namespace std ;
// Array capacity: the stated maximum node count (1000000) plus a little slack.
// This is a typed constant rather than a macro so that expressions such as 4*N
// multiply the whole value; the old unparenthesized macro made 4*N expand to
// 4*1000000 + 10.
const int N = 1000000 + 10 ;
typedef long long ll ;
// Modulus used when the per-node answers are folded into the printed value.
const int MO = 998244353 ;
// One node of the max segment tree that is indexed by absolute depth.
struct Tree {
    bool tag ;   // set when both children still have to be reset by Update()
    ll val ;     // maximum value currently stored in this node's range
} ;

// Heap-style storage: node v has children 2v and 2v+1, the root is T[1].
Tree T[4*N] ;
// Per-relative-depth scratch filled by Find(): flag[d] holds the Cnt stamp of the
// scan that last touched depth d, MaxDist[d] the largest Dist seen there.
int flag[N] ;
ll MaxDist[N] ;

// Ans[v]: answer for node v.  TAB[k]: 23333^k mod MO.  Dist[v]: weighted distance
// from the root to v.
ll Ans[N] ;
ll TAB[N] ;
ll Dist[N] ;

// Forward-star adjacency list: edge ids are 1..tot, id 0 terminates a list.
int Head[N] , Next[N] , Node[N] , Len[N] , tot ;

// Per-node edge-count bounds read from the input.
int L[N] , R[N] ;

// Depth[v]: depth of v (root is 1).  Deep[v]: height of v's subtree counted in
// nodes.  BigChild[v]: child with the tallest subtree (0 when v is a leaf).
int Depth[N] , Deep[N] , BigChild[N] ;

// n: node count.  MaxDeep: largest Depth in the tree.  Cnt: stamp for flag[].
int n , MaxDeep , Cnt ;

// ans: final folded output.  ret: running maximum written by Search().
ll ans , ret ;
// Reads the next unsigned decimal integer from stdin.
//
// Every non-digit character before the number is skipped (so a leading '-' is
// ignored, exactly as before). Returns the parsed value, or 0 when end of input
// is reached before any digit is seen.
int Read() {
    // getchar() returns int so that EOF is distinguishable from every character;
    // storing it in a char made the skip loop spin forever at end of input.
    int ch = getchar() ;
    while ( ch != EOF && ( ch < '0' || ch > '9' ) ) ch = getchar() ;
    int ret = 0 ;
    while ( ch >= '0' && ch <= '9' ) {
        ret = ret * 10 + ( ch - '0' ) ;
        ch = getchar() ;
    }
    return ret ;
}
// Adds the directed edge u -> v with weight w at the front of u's adjacency list.
void link( int u , int v , int w ) {
    const int id = ++tot ;   // edge ids start at 1; 0 marks the end of a list
    Node[id] = v ;
    Len[id] = w ;
    Next[id] = Head[u] ;
    Head[u] = id ;
}
// Recursive preprocessing pass over the subtree of x.
//
// Expects Depth[x] and Dist[x] to be set by the caller. For every descendant it
// fills Depth and Dist; on the way back it computes Deep (subtree height in
// nodes) and BigChild (the first child, in adjacency order, with the largest
// Deep). MaxDeep is raised to the largest Depth visited.
void PreDFS( int x ) {
    int tallest = 0 ;
    Deep[x] = 1 ;
    for ( int e = Head[x] ; e != 0 ; e = Next[e] ) {
        const int child = Node[e] ;
        Depth[child] = Depth[x] + 1 ;
        Dist[child] = Dist[x] + Len[e] ;
        PreDFS( child ) ;

        const int h = Deep[child] ;
        if ( h + 1 > Deep[x] ) Deep[x] = h + 1 ;
        // Strict comparison: on ties the earlier child stays the heavy one.
        if ( h > tallest ) {
            tallest = h ;
            BigChild[x] = child ;
        }
    }
    if ( Depth[x] > MaxDeep ) MaxDeep = Depth[x] ;
}
// Visits every node in the subtree of x and records, per relative depth, the
// largest Dist seen during the current scan (identified by Cnt).
//
// depth is the relative depth assigned to x; children get depth + 1.
void Find( int x , int depth ) {
    // First visit to this depth in the current scan: reset the slot lazily.
    if ( flag[depth] != Cnt ) {
        flag[depth] = Cnt ;
        MaxDist[depth] = 0 ;
    }
    if ( Dist[x] > MaxDist[depth] ) MaxDist[depth] = Dist[x] ;

    for ( int e = Head[x] ; e != 0 ; e = Next[e] )
        Find( Node[e] , depth + 1 ) ;
}
// Pushes a pending "clear" mark from node v down to its two children.
// Does nothing when v carries no mark.
void Update( int v ) {
    if ( T[v].tag ) {
        for ( int child = 2 * v ; child <= 2 * v + 1 ; ++child ) {
            T[child].val = -0x7FFFFFFF ;   // sentinel meaning "empty"
            T[child].tag = 1 ;
        }
        T[v].tag = 0 ;
    }
}
void Search( int v , int l , int r , int x , int y ) {
if ( x > y || x > MaxDeep || y < 1 ) return ;
if ( x < l ) x = l ;
if ( y > r ) y = r ;
if ( T[v].val <= ret ) return ;
if ( l == x && r == y ) {
ret = max( ret , T[v].val ) ;
return ;
}
Update(v) ;
int mid = (l + r) / 2 ;
if ( y <= mid ) Search( v + v , l , mid , x , y ) ;
else if ( x > mid ) Search( v + v + 1 , mid + 1 , r , x , y ) ;
else {
Search( v + v , l , mid , x , mid ) ;
Search( v + v + 1 , mid + 1 , r , mid + 1 , y ) ;
}
T[v].val = max( T[v+v].val , T[v+v+1].val ) ;
}
// Point update: position x becomes max(current value, value).
// v covers [l, r]; maxima of the nodes on the path are recomputed on the way up.
void Insert( int v , int l , int r , int x , ll value ) {
    if ( l == r && l == x ) {
        if ( value > T[v].val ) T[v].val = value ;
        return ;
    }

    Update( v ) ;
    const int mid = ( l + r ) / 2 ;
    const int lc = 2 * v ;
    const int rc = 2 * v + 1 ;
    if ( x > mid ) Insert( rc , mid + 1 , r , x , value ) ;
    else Insert( lc , l , mid , x , value ) ;
    T[v].val = max( T[lc].val , T[rc].val ) ;
}
// Range clear: resets positions [x, y] to the "empty" sentinel.
// v covers [l, r]; a fully covered node is marked and cleared lazily.
void Delete( int v , int l , int r , int x , int y ) {
    if ( x == l && y == r ) {
        T[v].tag = 1 ;
        T[v].val = -0x7FFFFFFF ;
        return ;
    }

    Update( v ) ;
    const int mid = ( l + r ) / 2 ;
    const int lc = 2 * v ;
    const int rc = 2 * v + 1 ;
    const bool reachesRight = y > mid ;
    const bool reachesLeft = x <= mid ;
    if ( reachesRight && reachesLeft ) {
        // The range straddles the midpoint: split it.
        Delete( lc , l , mid , x , mid ) ;
        Delete( rc , mid + 1 , r , mid + 1 , y ) ;
    } else if ( reachesRight ) {
        Delete( rc , mid + 1 , r , x , y ) ;
    } else {
        Delete( lc , l , mid , x , y ) ;
    }
    T[v].val = max( T[lc].val , T[rc].val ) ;
}
void DFS( int x , int keep ) {
for (int p = Head[x] ; p ; p = Next[p] ) {
if ( Node[p] == BigChild[x] ) continue ;
DFS( Node[p] , 0 ) ;
}
if ( BigChild[x] ) DFS( BigChild[x] , 1 ) ;
Insert( 1 , 1 , MaxDeep , Depth[x] , Dist[x] ) ;
for (int p = Head[x] ; p ; p = Next[p] ) {
if ( Node[p] == BigChild[x] ) continue ;
++ Cnt ;
Find( Node[p] , 1 ) ;
int UP = min( Deep[Node[p]] , R[x] ) ;
for (int i = 1 ; i <= UP ; i ++ ) {
if ( flag[i] != Cnt ) break ;
ret = -0x7FFFFFFF ;
Search( 1 , 1 , MaxDeep , Depth[x] + L[x] - i , Depth[x] + R[x] - i ) ;
Ans[x] = max( Ans[x] , ret + MaxDist[i] - 2ll * Dist[x] ) ;
}
for (int i = 1 ; i <= Deep[Node[p]] ; i ++ ) {
if ( flag[i] != Cnt ) break ;
Insert( 1 , 1 , MaxDeep , Depth[x] + i , MaxDist[i] ) ;
}
}
ret = -0x7FFFFFFF ;
Search( 1 , 1 , MaxDeep , Depth[x] + L[x] , Depth[x] + R[x] ) ;
Ans[x] = max( Ans[x] , ret - Dist[x] ) ;
if ( !keep ) Delete( 1 , 1 , MaxDeep , Depth[x] , Depth[x] + Deep[x] - 1 ) ;
}
int main() {
freopen( "watchdog.in" , "r" , stdin ) ;
freopen( "watchdog.out" , "w" , stdout ) ;
scanf( "%d" , &n ) ;
for (int i = 1 ; i <= n ; i ++ ) L[i] = Read() , R[i] = Read() ;
for (int i = 2 ; i <= n ; i ++ ) {
int u = Read() ;
int c = Read() ;
link( u , i , c ) ;
}
T[1].tag = 1 ;
T[1].val = -0x7FFFFFFF ;
memset( Ans , -1 , sizeof(Ans) ) ;
Depth[1] = 1 ;
PreDFS( 1 ) ;
DFS( 1 , 0 ) ;
TAB[0] = 1 ;
for (int i = 1 ; i <= n ; i ++ ) TAB[i] = TAB[i-1] * 23333ll % MO ;
for (int i = 1 ; i <= n ; i ++ ) {
ans = ((ans + (TAB[n-i] * (Ans[i] % MO) % MO)) % MO + MO) % MO ;
}
printf( "%lld\n" , ans ) ;
return 0 ;
}
以上.