解析
算法一
定义
u
p
x
,
k
up_{x,k}
upx,k 为节点
x
x
x 从自己的颜色所在位置在返祖链上往后跳
2
k
2^k
2k 个颜色到达的节点。
可以像倍增一样的求解。
这样对于一次询问
(
s
,
t
)
(s,t)
(s,t) 我们就能求出
(
s
,
l
c
a
)
(s,lca)
(s,lca) 这一段能取到哪里了。
对于向下的情况,再处理一个
u
p
x
,
k
′
up'_{x,k}
upx,k′ 表示节点
x
x
x 从自己的颜色所在位置在返祖链上往前跳
2
k
2^k
2k 个颜色到达的节点。
然后二分每一个询问的答案,从答案开始往前跳,看能否与
(
s
,
l
c
a
)
(s,lca)
(s,lca) 相接即可判定是否合法。
时间复杂度
O
(
n
log
n
+
m
log
C
log
n
)
O(n\log n+m\log C\log n)
O(nlogn+mlogClogn)。
算法二
考虑优化后一段
(
l
c
a
,
t
)
(lca,t)
(lca,t) 的过程。
假设询问
i
i
i 在
(
s
,
l
c
a
)
(s,lca)
(s,lca) 过程中跳到了颜色
c
c
c,就在
l
c
a
lca
lca 处增加一个
(
i
,
c
)
(i,c)
(i,c) 的元素,在
t
t
t 处打一个
(
i
)
(i)
(i) 标记。
考虑我们
d
f
s
dfs
dfs 过程中需要维护什么:
- 插入二元组 ( i , c ) (i,c) (i,c)
- 如果当前节点颜色为 c c c,收集器上的下一个颜色为 s u f suf suf,就使所有 ( i , c ) → ( i , s u f ) (i,c)\to(i,suf) (i,c)→(i,suf)。
- 查询当前的 i i i 元素的特征值。
- 撤销当前dfs的影响。
这个东西可以用可撤销并查集维护。
总复杂度
O
(
n
log
n
+
m
log
n
)
O(n\log n+m\log n)
O(nlogn+mlogn)
代码
写的是算法二。
#include<bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define debug(...) fprintf(stderr,__VA_ARGS__)
#define ok debug("OK\n")
using namespace std;
const int N=2e6+100;
const int M=50050;
const int mod=1e9+7;
const double eps=1e-9;
inline ll read() {
ll x(0),f(1);char c=getchar();
while(!isdigit(c)) {if(c=='-')f=-1;c=getchar();}
while(isdigit(c)) {x=(x<<1)+(x<<3)+c-'0';c=getchar();}
return x*f;
}
int n,m,C,Mx;
struct node{
int to,nxt;
}e[N<<1];
int fi[N],cnt;
inline void addline(int x,int y){
e[++cnt]=(node){y,fi[x]};fi[x]=cnt;
return;
}
int p[N],col[N];
int pre[N],up[N][20],pl[N][20],dep[N],suf[N];
void dfs(int x,int f){
dep[x]=dep[f]+1;
pl[x][0]=f;
for(int k=1;pl[x][k-1];k++) pl[x][k]=pl[pl[x][k-1]][k-1];
up[x][0]=pre[suf[col[x]]];
for(int k=1;up[x][k-1];k++) up[x][k]=up[up[x][k-1]][k-1];
int ori=pre[col[x]];pre[col[x]]=x;
for(int i=fi[x];~i;i=e[i].nxt){
int to=e[i].to;
if(to==f) continue;
dfs(to,x);
}
pre[col[x]]=ori;
return;
}
inline int Lca(int x,int y){
if(dep[x]<dep[y]) swap(x,y);
for(int k=17;k>=0;k--){
if(dep[pl[x][k]]<dep[y]) continue;
x=pl[x][k];
}
if(x==y) return x;
for(int k=17;k>=0;k--){
if(pl[x][k]==pl[y][k]) continue;
x=pl[x][k];y=pl[y][k];
}
return pl[x][0];
}
struct query{
int s,t,lca,id;
};
vector<query>v[N];
inline int jump(int x,int top){//return color;
for(int k=17;k>=0;k--){
if(dep[up[x][k]]<dep[top]) continue;
x=up[x][k];
}
return suf[col[x]];
}
struct add{
int id,c;
};
vector<add>ad[N];
vector<int>q[N];
void solve1(int x,int f){
int ori=pre[col[x]];pre[col[x]]=x;
for(query o:v[x]){
int s=o.s,t=o.t,lca=o.lca,id=o.id;
s=pre[p[1]];
if(dep[s]<dep[lca]) ad[lca].push_back((add){id,p[1]});
else ad[lca].push_back((add){id,jump(s,lca)});
q[t].push_back(id);
}
for(int i=fi[x];~i;i=e[i].nxt){
int to=e[i].to;
if(to==f) continue;
solve1(to,x);
}
pre[col[x]]=ori;
return;
}
int mx[N],fa[N],siz[N];
struct ope{
int op;//1:fa 2:siz 3:mx 4:bel
int id,ori;
}zhan[N<<3];
int top,nam[N],bel[N],tot;
int find(int x){
return fa[x]==x?x:find(fa[x]);
}
inline int New(int val){
++tot;fa[tot]=tot;siz[tot]=1;mx[tot]=val;
return tot;
}
void merge(int x,int y){
x=find(x);y=find(y);
if(siz[x]>siz[y]) swap(x,y);
zhan[++top]=(ope){1,x,fa[x]};fa[x]=y;
zhan[++top]=(ope){2,y,siz[y]};siz[y]+=siz[x];
zhan[++top]=(ope){3,y,mx[y]};mx[y]=max(mx[y],mx[x]);
return;
}
void del(int tim){
while(top!=tim){
if(zhan[top].op==1) fa[zhan[top].id]=zhan[top].ori;
else if(zhan[top].op==2) siz[zhan[top].id]=zhan[top].ori;
else if(zhan[top].op==3) mx[zhan[top].id]=zhan[top].ori;
else if(zhan[top].op==4) bel[zhan[top].id]=zhan[top].ori;
top--;
}
return;
}
int ans[N];
int rk[N];
void solve2(int x,int f){
int ori=top;
for(add o:ad[x]){
int id=o.id,c=o.c,now=New(0);
nam[id]=now;
merge(now,bel[c]);
//ans[id]=rk[c];
}
//assert(mx[find(bel[suf[col[x]]])]==rk[suf[col[x]]]);
//assert(mx[find(bel[col[x]])]==rk[col[x]]);
merge(bel[col[x]],bel[suf[col[x]]]);
//if(mx[find(bel[col[x]])]!=rk[col[x]]+1){
// debug("%d %d\n",mx[find(bel[col[x]])],rk[col[x]]);exit(0);
//}
//assert(mx[find(bel[col[x]])]==rk[col[x]]+1);
zhan[++top]=(ope){4,col[x],bel[col[x]]};bel[col[x]]=New(rk[col[x]]);
for(int id:q[x]){
int o=nam[id];
ans[id]=mx[find(o)];
}
for(int i=fi[x];~i;i=e[i].nxt){
int to=e[i].to;
if(to==f) continue;
solve2(to,x);
}
del(ori);
}
int main() {
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
#endif
memset(fi,-1,sizeof(fi));cnt=-1;
n=read();Mx=read();C=read();
for(int i=1;i<=C;i++) p[i]=read();
for(int i=1;i<=C;i++) suf[p[i]]=p[i+1],rk[p[i]]=i;
rk[0]=C+1;
for(int i=0;i<=Mx;i++) bel[i]=New(rk[i]);
for(int i=1;i<=n;i++) col[i]=read();
for(int i=1;i<n;i++){
int x=read(),y=read();
addline(x,y);addline(y,x);
}
dfs(1,0);
m=read();
for(int i=1;i<=m;i++){
int s=read(),t=read(),lca=Lca(s,t);
v[s].push_back((query){s,t,lca,i});
}
solve1(1,0);
solve2(1,0);
for(int i=1;i<=m;i++) printf("%d\n",ans[i]-1);
return 0;
}
/*
1
3 3
1000000 2000000
0 0
*/