题目
题意:
给一个有根树,树的每个节点都有个权值,且都不比其父亲权值大。然后有 q(1e5) 个询问,问从点 x 开始,走遍所有与其相连,权值在 [l,r] 以内的点的个数。
分析
由于这个树权值的性质,对于每个询问,其实可以先从 x 往上跳到最大的小于等于 r 的祖先 F ,然后查询以 F 为根的子树中大于等于 l 的点的个数即可。
可以用主席树。
但因为是树上的区间值域维护,所以也可以利用 dfs序 ,用一个树状数组就行。
时间复杂度 O n lgn
代码
#include <bits/stdc++.h>
using namespace std;
#define MAXN 100005
#define MAXQ 100005
#define INF 1000000009
vector<int> graph[MAXN];
vector<int> Qid[MAXN];
unordered_map<int,int> Hash;
int n,q;
int ST[MAXN][25];
int a[MAXN];
int xx[MAXQ],L[MAXQ],R[MAXQ];
int ans[MAXQ];
int SUM,MAXS;
int cntid,Lid[MAXN],Rid[MAXN];
int c[MAXN*5];
#define lowbit(x) ((x)&(-(x)))
void ADD(int x, int v) {
for (; x<=MAXS; x+=lowbit(x)) c[x]+=v;
SUM+=v;
}
int QRY(int x)
{
int ret=0;
for (x--; x>0; x-=lowbit(x)) ret+=c[x];
return SUM-ret;
}
void dfs1(int X, int fX)
{
ST[X][0]=fX;
Lid[X]=++cntid;
for (int o: graph[X]) if (o!=fX) dfs1(o,X);
Rid[X]=cntid;
}
void getST()
{
for (int j=1; j<25; j++)
for (int i=1; i<=n; i++)
ST[i][j]=ST[ST[i][j-1]][j-1];
}
void dfs_modify(int X, int fX, int Val)
{
ADD(a[X],Val);
for (int o: graph[X]) if (o!=fX)
dfs_modify(o,X,Val);
}
void dfs2(int X, int fX, int flag)
{
flag+=Qid[X].size();
for (int o: graph[X]) if (o!=fX)
dfs2(o,X,flag);
ADD(a[X],1);
int tmps=Qid[X].size();
while (Qid[X].size()) {
ans[Qid[X].back()]=QRY(L[Qid[X].back()]);
Qid[X].pop_back();
}
flag-=tmps;
if (flag==0) dfs_modify(X,fX,-1);
}
void dfs3(int X, int fX)
{
for (int i: Qid[Lid[X]]) if (i<0) ans[-i]-=QRY(L[-i]);
ADD(a[X],1);
for (int i: Qid[Lid[X]]) if (i>0) ans[i]+=QRY(L[i]);
for (int o: graph[X]) if (o!=fX) dfs3(o,X);
}
int main()
{
scanf("%d",&n);
vector<int> tmphash;
for (int i=1,uu,vv; i<n; i++) {
scanf("%d%d",&uu,&vv);
graph[uu].push_back(vv);
graph[vv].push_back(uu);
}
for (int i=1; i<=n; i++) {
scanf("%d",&a[i]);
tmphash.push_back(a[i]);
}
scanf("%d",&q);
for (int i=1; i<=q; i++) {
scanf("%d%d%d",&xx[i],&L[i],&R[i]);
tmphash.push_back(L[i]);
tmphash.push_back(R[i]);
}
sort(tmphash.begin(),tmphash.end());
for (int i=0,num=1,mi=tmphash.size(); i<mi; i++,num+=(tmphash[i]!=tmphash[i-1])) {
Hash[tmphash[i]]=num;
MAXS=num;
}
for (int i=1; i<=n; i++) a[i]=Hash[a[i]];
for (int i=1; i<=q; i++) {
L[i]=Hash[L[i]];
R[i]=Hash[R[i]];
}
a[0]=INF;
MAXS+=10;
dfs1(1,0);
getST();
//scanf("%d",&q);
for (int i=1,tmp; i<=q; i++) {
//scanf("%d%d%d",&xx[i],&L[i],&R[i]);
if (a[xx[i]]<L[i] || a[xx[i]]>R[i]) {
ans[i]=0;
continue;
}
tmp=xx[i];
for (int j=24; j>=0; j--)
if (ST[tmp][j] && a[ST[tmp][j]]<=R[i])
tmp=ST[tmp][j];
if (a[tmp]<=R[i]) {
//Qid[tmp].push(i);
Qid[Lid[tmp]].push_back(-i);
Qid[Rid[tmp]].push_back(i);
}
else
ans[i]=0;
}
dfs3(1,0);
for (int i=1; i<=q; i++) printf("%d\n",ans[i]);
return 0;
}
/*
9
1 2
1 3
2 4
2 5
5 6
6 7
3 8
3 9
10 1 9 1 1 1 1 1 1
9 8 5 3 6 4 1 2 1
*/