传送门:【HDU】4918 Query on the subtree
题目分析:
首先,简化问题。
1.求一次到点u的距离不超过d的点的个数。很容易,一次O(NlogN)的点分治便可以完成。
2.多次进行操作1。此时不能每次都O(NlogN)了,太慢了。我们考虑到对于点分治,树的重心一共有logN层,第一层为整棵树的重心,第二层为第一层重心的子树的重心,以此类推,每次至少分成两个大小差不多的子树,所以一共有logN层。而且,对于一个点,他最多只属于logN个子树,也就是最多只属于logN个重心。所以我们可以预处理出每个点所属于的重心以及到这些重心的距离,以每个重心建树状数组,每个点按照到重心的距离插入到树状数组中,然后每次查询到u距离不超过d的点的个数就通过树状数组求前缀和得到。
假设一个重心x到u的距离为dis,那么便统计到重心x距离不超过d-dis的点的个数,这个过程我们称之为“借力”,本身能力有限,所以需要借助x的影响力。因为如果这个重心被u借力了,那么这个重心的子重心一定也被借力,由于相邻被借力的两个重心x、y所统计的点会有重复,所以我们需要去重。去重的话我们就通过对每个节点再开一个v对x的树状数组,这个树状数组的意义为:重心x的子树v的重心为y时,子树v中每个点到x的距离为下标建立的树状数组。因为重心x与重心y交集的部分,重心x包括的部分重心y一定包括,所以统计的时候减去v对x的树状数组中距x不超过d-dis的点的个数即可。访问u所属与的所有重心,挨个借力,同时去重,便能得到距离u不超过d的点的个数。因为重心最多logN层,每个树状数组最多N个点,logN复杂度的统计,所以每次查询复杂度O(logN*logN)。
我们最多为每个节点开2个树状数组,而且每一层所有树状数组的大小相加不超过N,所以树状数组的占用空间为O(2NlogN)。
3.原问题:每个点有权值,一共两种操作:(1)询问到u距离不超过d的所有点的权值和为多少;(2)将一个点的权值更改为d。
这个就是在上面的基础上稍做扩充。预处理的时候插入树状数组的就是该点的权值,查询依旧是统计前缀和。修改点权值的时候,便是和查询一样,在u距重心x距离d的位置在x的树状数组中修改u的权值,同时修改u属于重心x的子树v的v对x的树状数组中相同位置的值。复杂度和查询一样为O(logN*logN)。
算法总复杂度:O(NlogN*logN)。
--------------------------------分割线---------------------------------
至此这道题便成功解决了。
这里说一下我额外的感受:
假设重心x的其中一棵子树v的重心为y,那么我们认为x是y的父节点,y是x的子节点。
在这一假设下,所有的重心便构成了一棵树,这棵树上每个重心到根重心(整棵子树的重心)的距离不超过logN,同理两个重心之间的路径长度不会超过2logN。
我们可以形象的称这一棵树为“重心树”~
正是在重心树上任意一点到根结点的路径长度不超过logN的性质下,算法的效率得到了保障~
我相信这棵树一定还会有更多的好性质待发现。(比如将树状数组换成别的数据结构如平衡树什么的)
代码如下:
#include <vector>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std ;
#define rep( i , a , b ) for ( int i = ( a ) ; i < ( b ) ; ++ i )
#define For( i , a , b ) for ( int i = ( a ) ; i <= ( b ) ; ++ i )
#define rev( i , a , b ) for ( int i = ( a ) ; i >= ( b ) ; -- i )
#define clr( a , x ) memset ( a , x , sizeof a )
const int MAXN = 100005 ;
const int MAXE = 200005 ;
const int MAXNODE = 4000005 ;
const int INF = 0x3f3f3f3f ;
struct Node {
int root , subroot , dis , n ;
Node () {}
Node ( int root , int subroot , int dis , int n ) : root ( root ) , subroot ( subroot ) , dis ( dis ) , n ( n ) {}
} ;
struct Edge {
int v , n ;
Edge () {}
Edge ( int v , int n ) : v ( v ) , n ( n ) {}
} ;
struct Tree {
int n ;
vector < int > T ;
void init ( int size ) {
T.clear () ;
n = size ;
For ( i , 0 , n ) T.push_back ( 0 ) ;
}
void add ( int x , int v ) {
for ( int i = x ; i <= n ; i += i & -i ) T[i] += v ;
}
int sum ( int x , int ans = 0 ) {
if ( x > n ) x = n ;
for ( int i = x ; i > 0 ; i -= i & -i ) ans += T[i] ;
return ans ;
}
} ;
Tree T[MAXN << 1] ;
Edge E[MAXE] ;
Node N[MAXNODE] ;
int H[MAXN] , cntE ;
int node[MAXN] , cntN ;
int Q[MAXN] , head , tail ;
int val[MAXN] ;
int vis[MAXN] , Time ;
int dep[MAXN] ;
int siz[MAXN] ;
int pre[MAXN] ;
int idx[MAXN] ;
int cur ;
int n , m ;
int tot_size ;
void clear () {
cur = 0 ;
++ Time ;
cntN = cntE = 0 ;
clr ( H , -1 ) ;
clr ( node , -1 ) ;
}
void addedge ( int u , int v ) {
E[cntE] = Edge ( v , H[u] ) ;
H[u] = cntE ++ ;
}
void addnode ( int u , int root , int subroot , int dis ) {
N[cntN] = Node ( root , subroot , dis , node[u] ) ;
node[u] = cntN ++ ;
}
int get_root ( int s ) {
head = tail = 0 ;
Q[tail ++] = s ;
pre[s] = 0 ;
while ( head != tail ) {
int u = Q[head ++] ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( v == pre[u] || vis[v] == Time ) continue ;
pre[v] = u ;
Q[tail ++] = v ;
}
}
tot_size = tail ;
int root = s , root_cnt = INF ;
while ( tail ){
int u = Q[-- tail] ;
siz[u] = 1 ;
int cnt = 0 ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( v == pre[u] || vis[v] == Time ) continue ;
siz[u] += siz[v] ;
if ( siz[v] > cnt ) cnt = siz[v] ;
}
cnt = max ( cnt , tot_size - siz[u] ) ;
if ( cnt < root_cnt ) {
root = u ;
root_cnt = cnt ;
}
}
return root ;
}
void calc ( int s , int root , int subroot ) {
head = tail = 0 ;
Q[tail ++] = s ;
pre[s] = 0 ;
dep[s] = 2 ;//dep[root] = 1
while ( head != tail ) {
int u = Q[head ++] ;
T[root].add ( dep[u] , val[u] ) ;
T[subroot].add ( dep[u] , val[u] ) ;
addnode ( u , root , subroot , dep[u] ) ;
for ( int i = H[u] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( v == pre[u] || vis[v] == Time ) continue ;
pre[v] = u ;
dep[v] = dep[u] + 1 ;
Q[tail ++] = v ;
}
}
}
void divide ( int u ) {
int root = get_root ( u ) ;
vis[root] = Time ;
idx[root] = ++ cur ;
T[cur].init ( tot_size ) ;
T[cur].add ( 1 , val[root] ) ;
addnode ( root , idx[root] , 0 , 1 ) ;
for ( int i = H[root] ; ~i ; i = E[i].n ) {
int v = E[i].v ;
if ( vis[v] == Time ) continue ;
T[++ cur].init ( siz[v] + 1 ) ;
calc ( v , idx[root] , cur ) ;
}
for ( int i = H[root] ; ~i ; i = E[i].n ) if ( vis[E[i].v] != Time ) divide ( E[i].v ) ;
}
void solve () {
clear () ;
char op[5] ;
int u , v , d ;
For ( i , 1 , n ) scanf ( "%d" , &val[i] ) ;
rep ( i , 1 , n ) {
scanf ( "%d%d" , &u , &v ) ;
addedge ( u , v ) ;
addedge ( v , u ) ;
}
divide ( 1 ) ;
//For ( i , 1 , cur ) printf ( "%d\n" , T[i].sum ( T[i].n ) ) ;
while ( m -- ) {
scanf ( "%s%d%d" , op , &u , &d ) ;
if ( op[0] == '?' ) {
int ans = 0 ;
for ( int i = node[u] ; ~i ; i = N[i].n ) {
int root = N[i].root , subroot = N[i].subroot , dis = N[i].dis - 1 ;
ans += T[root].sum ( d - dis + 1 ) ;
if ( subroot ) ans -= T[subroot].sum ( d - dis + 1 ) ;
}
printf ( "%d\n" , ans ) ;
} else {
for ( int i = node[u] ; ~i ; i = N[i].n ) {
int root = N[i].root , subroot = N[i].subroot , dis = N[i].dis ;
T[root].add ( dis , d - val[u] ) ;
if ( subroot ) T[subroot].add ( dis , d - val[u] ) ;
}
val[u] = d ;
}
}
}
int main () {
clr ( vis , 0 ) ;
Time = 0 ;
while ( ~scanf ( "%d%d" , &n , &m ) ) solve () ;
return 0 ;
}