题目大意
给定一棵
n
个节点的树,每个节点有颜色
∙i≠j
∙i∈path(x,y)
∙j∈path(u,v)
∙ci=cj
1≤n≤105,1≤q≤5×104,1≤ci≤109
题目分析
这题一看不可做。
我们来思考一下这题有哪些方法可以做。
先不考虑
i≠j
的限制。
Algorithm 1
分别处理每一种颜色。枚举颜色,然后计算每个点到根路径上的同色点的个数。这种颜色对于一个询问的贡献就是两条路径点数相乘即可。
令颜色种数为
C
,那么时间复杂度是
Algorithm 2
对于单种颜色,令
sum(x)
到根路径上的同色点个数,询问形如
将询问差分拆成16个。于是每个询问变成两条到根路径同色点对个数。
怎么做呢?将所有询问挂到对应的两个点上,DFS这棵树。每到一个点,先枚举每一个同色的点,那么显然如果一个询问的两个点分别在这两个点的子树内,那这对点就会对询问有贡献。因此我们对这个同色点的子树打一个 +1 标记。然后我们处理挂在这个点上的询问,如果询问的另一个点已经访问过(为了不算重),那么其对答案的贡献就是这个点的标记和(注意符号由差分询问时决定)。然后在退出节点的时候我们就撤销所有这个节点产生的标记。
注意为了处理特殊情况,一个点应该在访问完所有的儿子之后才标记为访问过。
至于处理标记,使用树状数组就可以了。
细节挺多的,请自己思考吧。
令 A 表示一种颜色最多的出现次数,这种算法的时间复杂度是
Algorithm 3
我们发现上面两种方法都过不了,怎么办呢?
我们可以使用阈值均衡两种算法!
设置阈值
B
,如果一种颜色出现次数大于
时间复杂度是
O(⌊nB⌋n+Bnlog2n)
。阈值瞎选一个就好了,我选的是
n√
。
但是这样我们还有一个问题没有考虑,就是
i≠j
怎么处理?
可以发现,对于一个询问的两条路径,算重的显然是两条路径交集的长度。那么我们可以瞎分类讨论一波(你讨厌分类讨论可以直接打树剖),从答案中减去就好了。
此题完美解决!
Algorithm n?
貌似这题在hackerrank上的标解是树链剖分(捂脸)。
貌似国外有人用树上莫队过了?怎么做,不会~
代码实现
好久没有打过这么复杂的代码了……
#include <algorithm>
#include <iostream>
#include <cstdio>
#include <cctype>
#include <cmath>
using namespace std;
typedef long long LL;
int read()
{
int x=0,f=1;
char ch=getchar();
while (!isdigit(ch)) f=ch=='-'?-1:f,ch=getchar();
while (isdigit(ch)) x=x*10+ch-'0',ch=getchar();
return x*f;
}
int buf[30];
void write(LL x)
{
if (x<0) putchar('-'),x=-x;
for (;x;x/=10) buf[++buf[0]]=x%10;
if (!buf[0]) buf[++buf[0]]=0;
for (;buf[0];) putchar('0'+buf[buf[0]--]);
}
const int N=100050;
const int Q=50050;
const int M=N<<1;
const int EL=N<<1;
const int LGEL=18;
const int T=Q<<1;
struct query
{
int qid,x,sign;
}qs[T<<4];
struct node
{
int pid,col;
}ns[N];
bool operator<(node x,node y){return x.col<y.col;}
int last[N],fa[N],pos[N],ptr[N],deep[N],c[N],v[N],qlst[N],DFN[N],size[N];
int n,tot,el,lgel,idx,thr,qtot,ctp,q;
int euler[EL],LOG[EL];
bool mark[N],vis[N];
int tov[M],nxt[M];
int rmq[EL][LGEL];
int qnxt[T<<4];
int qy[Q][4];
LL ans[Q];
int lowbit(int x){return x&-x;}
struct Fenwick_tree
{
int num[N];
int query(int x)
{
int ret=0;
for (;x;x-=lowbit(x)) ret+=num[x];
return ret;
}
void modify(int x,int delta){for (;x<=n;x+=lowbit(x)) num[x]+=delta;}
}t;
void insert(int x,int y){tov[++tot]=y,nxt[tot]=last[x],last[x]=tot;}
void hang(int x,int y,int id,int sign){qs[++qtot].qid=id,qs[qtot].x=y,qs[qtot].sign=sign,qnxt[qtot]=qlst[x],qlst[x]=qtot;}
void build(int x,int y,int id,int sign)
{
hang(x,y,id,sign);
if (x!=y) hang(y,x,id,sign);
}
void dfs(int x)
{
size[rmq[pos[euler[++el]=x]=el][0]=x]=1,DFN[x]=++idx;
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x])
fa[y]=x,deep[y]=deep[x]+1,dfs(y),euler[++el]=x,rmq[el][0]=x,size[x]+=size[y];
}
void pre_rmq()
{
lgel=trunc(log(el)/log(2));
for (int j=1;j<=lgel;j++)
for (int i=1;i+(1<<j)-1<=el;i++)
rmq[i][j]=deep[rmq[i][j-1]]<deep[rmq[i+(1<<j-1)][j-1]]?rmq[i][j-1]:rmq[i+(1<<j-1)][j-1];
LOG[1]=0;
for (int i=2;i<=el;i++) LOG[i]=LOG[i-1]+(1<<LOG[i-1]+1==i);
}
int get_rmq(int l,int r)
{
int lgr=LOG[r-l+1];
return deep[rmq[l][lgr]]<deep[rmq[r-(1<<lgr)+1][lgr]]?rmq[l][lgr]:rmq[r-(1<<lgr)+1][lgr];
}
int lca(int x,int y)
{
if ((x=pos[x])>(y=pos[y])) swap(x,y);
return get_rmq(x,y);
}
void sum(int x)
{
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x]) v[y]+=v[x],sum(y);
}
int getsum(int x,int y)
{
int z=lca(x,y);
return v[x]+v[y]-v[z]-v[fa[z]];
}
void pre_color()
{
sort(ns+1,ns+1+n);
ns[0].col=0,ctp=0;
for (int i=1;i<=n;i++)
{
if (ns[i].col!=ns[i-1].col) ptr[++ctp]=i;
c[ns[i].pid]=ctp;
}
ptr[ctp+1]=n+1;
for (int i=1;i<=ctp;i++)
if (ptr[i+1]-ptr[i]<=thr) mark[i]=1;
else
{
for (int j=1;j<=n;j++) v[j]=0;
for (int j=ptr[i];j<ptr[i+1];j++) v[ns[j].pid]++;
sum(1);
for (int j=1;j<=q;j++) ans[j]+=1ll*getsum(qy[j][0],qy[j][1])*getsum(qy[j][2],qy[j][3]);
}
}
void solve(int x)
{
if (mark[c[x]]) for (int i=ptr[c[x]],y;i<ptr[c[x]+1];i++) t.modify(DFN[y=ns[i].pid],1),t.modify(DFN[y]+size[y],-1);
for (int i=last[x],y;i;i=nxt[i])
if ((y=tov[i])!=fa[x]) solve(y);
vis[x]=1;
for (int i=qlst[x];i;i=qnxt[i])
if (vis[qs[i].x]) ans[qs[i].qid]+=t.query(DFN[qs[i].x])*qs[i].sign;
if (mark[c[x]]) for (int i=ptr[c[x]],y;i<ptr[c[x]+1];i++) t.modify(DFN[y=ns[i].pid],-1),t.modify(DFN[y]+size[y],1);
}
void process()
{
for (int i=1;i<=q;i++)
{
int x=qy[i][0],y=qy[i][1],z=lca(x,y),u=qy[i][2],v=qy[i][3],w=lca(u,v),tmp=0;
if (DFN[w]<DFN[z]&&DFN[z]<=DFN[w]+size[w]-1)
{
if (lca(u,z)==z) tmp=max(deep[lca(x,u)]-deep[z],deep[lca(y,u)]-deep[z])+1;
else if (lca(v,z)==z) tmp=max(deep[lca(x,v)]-deep[z],deep[lca(y,v)]-deep[z])+1;
}
else
if (DFN[z]<=DFN[w]&&DFN[w]<=DFN[z]+size[z]-1)
if (w==z)
{
int a=lca(x,u),b=lca(x,v);
tmp=deep[a]-deep[z]+deep[b]-deep[z];
a=lca(y,u),b=lca(y,v);
tmp+=deep[a]-deep[z]+deep[b]-deep[z];
tmp++;
}
else
{
int a=lca(x,u),b=lca(x,v);
if (a==w||b==w) tmp=max(deep[a]-deep[b],deep[b]-deep[a])+1;
a=lca(y,u),b=lca(y,v);
if (a==w||b==w) tmp=max(deep[a]-deep[b],deep[b]-deep[a])+1;
}
ans[i]-=tmp;
}
}
int main()
{
freopen("count.in","r",stdin),freopen("count.out","w",stdout);
n=read(),q=read(),thr=trunc(sqrt(n));
for (int i=1;i<=n;i++) ns[i].col=c[i]=read(),ns[i].pid=i;
for (int i=1,x,y;i<n;i++)
{
x=read(),y=read();
insert(x,y),insert(y,x);
}
fa[1]=0,deep[1]=1,dfs(1),pre_rmq();
for (int i=1;i<=q;i++)
{
for (int j=0;j<4;j++) qy[i][j]=read();
int x=qy[i][0],y=qy[i][1],z=lca(x,y),u=qy[i][2],v=qy[i][3],w=lca(u,v);
build(x,u,i,1),build(x,w,i,-1),build(x,v,i,1);
if (fa[w]) build(x,fa[w],i,-1);
if (fa[z])
{
build(fa[z],u,i,-1),build(fa[z],w,i,1),build(fa[z],v,i,-1);
if (fa[w]) build(fa[z],fa[w],i,1);
}
build(z,u,i,-1),build(z,w,i,1),build(z,v,i,-1);
if (fa[w]) build(z,fa[w],i,1);
build(y,u,i,1),build(y,w,i,-1),build(y,v,i,1);
if (fa[w]) build(y,fa[w],i,-1);
}
pre_color(),solve(1),process();
for (int i=1;i<=q;i++) write(ans[i]),putchar('\n');
fclose(stdin),fclose(stdout);
return 0;
}