题目大意
有一颗N个节点的树和M个询问,每个点有一种颜色。每次询问u到v路径上把颜色s和颜色t当作同一种颜色后路径上不同颜色的数目。
n<=5*10^4,m<=10^5
树上莫队
注意到这题符合莫队算法特征。
于是直接树上莫队即可,用num[x]表示颜色x出现的次数,那么对于把颜色s和t当作同一种颜色只需要看num[s]和num[t]是否都大于0,注意考虑s=t的情况。
#include<cstdio>
#include<cmath>
#include<algorithm>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int maxn=50000+10,maxm=100000+10;
struct dong{
int u,v,l,r,s,t,id;
bool p;
};
int belong[maxn*2],h[maxn],go[maxn*2],next[maxn*2],a[maxn*2],fi[maxn],la[maxn],num[maxn],co[maxn],d[maxn],ans[maxm];
bool bz[maxn];
int f[maxn][20];
dong ask[maxm];
int i,j,k,l,r,s,t,n,m,u,v,w,tot,top,c,now;
void add(int x,int y){
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void dfs(int x,int y){
d[x]=d[y]+1;
f[x][0]=y;
a[++top]=x;
fi[x]=top;
int t=h[x];
while (t){
if (go[t]!=y) dfs(go[t],x);
t=next[t];
}
a[++top]=x;
la[x]=top;
}
bool cmp(dong a,dong b){
if (belong[a.l]<belong[b.l]) return 1;
else if (belong[a.l]==belong[b.l]&&a.r<b.r) return 1;
else return 0;
}
void change(int x){
if (bz[x]){
num[co[x]]--;
if (num[co[x]]==0) now--;
}
else{
num[co[x]]++;
if (num[co[x]]==1) now++;
}
bz[x]^=1;
}
int lca(int x,int y){
int j;
if (d[x]<d[y]) swap(x,y);
if (d[x]!=d[y]){
j=floor(log(d[x]-d[y])/log(2));
while (j>=0){
if (d[f[x][j]]>d[y]) x=f[x][j];
j--;
}
x=f[x][0];
}
if (x==y) return x;
j=floor(log(d[x])/log(2));
while (j>=0){
if (f[x][j]!=f[y][j]){
x=f[x][j];
y=f[y][j];
}
j--;
}
return f[x][0];
}
int main(){
scanf("%d%d",&n,&m);
fo(i,1,n) scanf("%d",&co[i]);
fo(i,1,n){
scanf("%d%d",&j,&k);
if (j&&k) add(j,k),add(k,j);
}
dfs(1,0);
fo(j,1,floor(log(n)/log(2)))
fo(i,1,n)
f[i][j]=f[f[i][j-1]][j-1];
fo(i,1,m){
scanf("%d%d",&j,&k);
ask[i].u=j;ask[i].v=k;
if (fi[j]>fi[k]) swap(j,k);
if (fi[k]<la[j]) ask[i].l=fi[j],ask[i].r=fi[k],ask[i].p=0;else ask[i].l=la[j],ask[i].r=fi[k],ask[i].p=1;
scanf("%d%d",&ask[i].s,&ask[i].t);
ask[i].id=i;
}
c=floor(sqrt(n*2))+1;
fo(i,1,n*2) belong[i]=(i-1)/c+1;
sort(ask+1,ask+m+1,cmp);
l=r=1;
change(a[1]);
fo(i,1,m){
while (l<ask[i].l){
change(a[l]);
l++;
}
while (l>ask[i].l){
l--;
change(a[l]);
}
while (r<ask[i].r){
r++;
change(a[r]);
}
while (r>ask[i].r){
change(a[r]);
r--;
}
if (ask[i].p){
u=ask[i].u;v=ask[i].v;
w=lca(u,v);
change(w);
}
j=now;
s=ask[i].s;t=ask[i].t;
if (s!=t&&num[s]&&num[t]) j--;
ans[ask[i].id]=j;
if (ask[i].p) change(w);
}
fo(i,1,m) printf("%d\n",ans[i]);
}
分块大法好
由于我一开始是不会莫队的,该题目有一个部分分保证树是一条链,于是我们可以使用分块大法好。
用cnt[i,j]表示第i块到第j块不同的颜色数,sum[i,j]表示颜色i在前j块出现的次数。显然这两个都可以在o(n^1.5)预处理。
那么对于询问j~k怎么做呢?显然如果j和k之间没有跨越整一块可以直接暴力了。否则,先找到j所位于的块l,k所位于的块r,然后ans就为cnt[l+1][r-1]。接下来对残余部分进行统计。比如统计到一个颜色x,那么我们需要知道x在l+1~r-1块里有没有出现,这个可以使用sum,还有在残余部分有没有出现,这个可以用一个桶。最后判一下s和t就行了。
#include<cstdio>
#include<algorithm>
#include<cmath>
#define fo(i,a,b) for(i=a;i<=b;i++)
using namespace std;
const int n1=100+10,n2=50000+10;
int father[n1],d[n1],co[n2],num[n2];
int h[n2],go[n2*2],next[n2*2],a[n2],belong[n2],ru[n2],dfn[n2];
int cnt[250][250],sum[n2][250];
int i,j,k,l,r,s,t,n,m,u,v,tot,top,ans,c;
bool czy;
void add(int x,int y){
ru[x]++;ru[y]++;
go[++tot]=y;
next[tot]=h[x];
h[x]=tot;
}
void init(){
scanf("%d%d",&n,&m);
fo(i,1,n) scanf("%d",&co[i]);
fo(i,1,n){
scanf("%d%d",&j,&k);
if (j&&k){
add(j,k);
add(k,j);
}
}
}
void build(int x){
int t=h[x];
while (t){
if (!d[go[t]]){
father[go[t]]=x;
d[go[t]]=d[x]+1;
build(go[t]);
}
t=next[t];
}
}
void solve1(){
d[1]=1;
build(1);
while (m--){
scanf("%d%d%d%d",&u,&v,&j,&k);
while (u!=v){
if (d[u]>d[v]){
num[co[u]]++;
u=father[u];
}
else{
num[co[v]]++;
v=father[v];
}
}
num[co[u]]++;
ans=0;
fo(i,1,n)
if (num[i]) ans++;
if (j!=k&&num[j]&&num[k]) ans--;
printf("%d\n",ans);
fo(i,1,n) num[i]=0;
}
czy=1;
}
void dfs(int x){
a[++top]=co[x];
dfn[x]=top;
int t=h[x];
while (t){
if (!dfn[go[t]]) dfs(go[t]);
t=next[t];
}
}
void solve2(){
if (czy) return;
c=floor(sqrt(n))+1;
fo(i,1,n)
if (ru[i]==2){
k=i;
break;
}
dfs(k);
fo(i,1,n) belong[i]=(i-1)/c+1;
fo(i,1,n) sum[a[i]][belong[i]]++;
fo(i,1,n)
fo(j,1,belong[n])
sum[i][j]+=sum[i][j-1];
fo(i,1,belong[n]){
if (i==50){
t=t;
}
fo(j,(i-1)*c+1,n){
if (!num[a[j]]) cnt[i][belong[j]]++;
num[a[j]]++;
}
fo(j,1,n) num[j]=0;
}
fo(i,1,belong[n])
fo(j,i+1,belong[n])
cnt[i][j]+=cnt[i][j-1];
while (m--){
ans=0;
scanf("%d%d%d%d",&j,&k,&s,&t);
j=dfn[j];k=dfn[k];
if (j>k) swap(j,k);
l=belong[j];r=belong[k];
if (r-l<=1){
fo(i,j,k){
if (!num[a[i]]) ans++;
num[a[i]]++;
}
if (s!=t&&num[s]&&num[t]) ans--;
printf("%d\n",ans);
fo(i,j,k) num[a[i]]--;
continue;
}
ans=cnt[l+1][r-1];
fo(i,j,min(l*c,n)){
if (!num[a[i]]&&sum[a[i]][r-1]-sum[a[i]][l]==0) ans++;
num[a[i]]++;
}
fo(i,(r-1)*c+1,k){
if (!num[a[i]]&&sum[a[i]][r-1]-sum[a[i]][l]==0) ans++;
num[a[i]]++;
}
if (s!=t&&num[s]+sum[s][r-1]-sum[s][l]>0&&num[t]+sum[t][r-1]-sum[t][l]>0) ans--;
printf("%d\n",ans);
fo(i,j,min(l*c,n)) num[a[i]]--;
fo(i,(r-1)*c+1,k) num[a[i]]--;
}
}
int main(){
init();
if (n<=100&&m<=100) solve1();
solve2();
}