目的
老是打错,现在系统搞一遍。
例题
分析
经典的主席树,要找LCA。
LCA
首先每访问一次一个点,dfn++,设le[x]为点x最小dfn,ref[y]为dfn=y的时候在哪个点上。f[y][x]为dfn为x开始,长度为
1<<y
的一段最小dis的点的编号。
注意搞f的时候每次要
i=1 dfn−(1<<(j))
,爆掉可能有问题,然后
1<<j
要打括号,+-是优先于<<的。这东西调了我30min。
代码
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;
#define fo(i,j,k) for(i=j;i<=k;i++)
#define fd(i,j,k) for(i=j;i>=k;i--)
typedef long long ll;
const int N=600005;
struct rec1
{
int val,l,r;
}tr[N*10];
struct rec
{
int x,y;
}B[N];
int st[N],f[20][N],dis[N],le[N],ri[N],dfn,ref[N],rev[N],t1,t2,x,y,z,c1,c2,c3,zz,n,m,q,i,j,a[N],Log[N],lastans,lc,fa[N];
int tt,b[N],next[N],first[N];
int k1,k2,k3;
void cr(int x,int y)
{
tt++;
b[tt]=y;
next[tt]=first[x];
first[x]=tt;
}
bool cmp(rec a,rec b)
{
return a.x<b.x;
}
void lsh()
{
fo(i,1,n)
{
B[i].x=a[i];
B[i].y=i;
}
sort(B+1,B+1+n,cmp);
t1=0;
B[0].x=-1;
fo(i,1,n)
{
if (B[i].x!=B[i-1].x)
rev[++t1]=B[i].x;
a[B[i].y]=t1;
}
rev[t1+1]=1e9;
rev[0]=0;
}
int disperse(int z)
{
int l=1,r=t1+1,mid=0;
while (l<r)
{
mid=(r+l)/2;
if (rev[mid]>z) r=mid;
else l=mid+1;
if (rev[mid]==z) return mid;
}
return l;
}
void change(int x_1,int x_2,int l,int r,int pos)
{
int m=(l+r)/2;
if (l==r)
{
tr[x_1].val=tr[x_2].val+1;
return ;
}
if (m>=pos)
{
tr[x_1].l=++t2;
tr[x_1].r=tr[x_2].r;
change(t2,tr[x_2].l,l,m,pos);
}
else
{
tr[x_1].r=++t2;
tr[x_1].l=tr[x_2].l;
change(t2,tr[x_2].r,m+1,r,pos);
}
tr[x_1].val=tr[tr[x_1].l].val+tr[tr[x_1].r].val;
}
int get(int x,int l,int r,int i,int j)
{
if (i>j) return 0;
int m=(l+r)/2;
if (l==i&&r==j)
return tr[x].val;
if (m>=j)
return get(tr[x].l,l,m,i,j);
else if (m<i)
return get(tr[x].r,m+1,r,i,j);
else return get(tr[x].l,l,m,i,m)+get(tr[x].r,m+1,r,m+1,j);
}
void dfs(int x,int y)
{
fa[x]=y;
st[x]=++t2;
change(st[x],st[y],1,t1,a[x]);
dis[x]=dis[y]+1;
dfn++;
ref[dfn]=x;
le[x]=dfn;
for(int p=first[x];p;p=next[p])
if (b[p]!=y)
{
dfs(b[p],x);
dfn++;
ref[dfn]=x;
}
ri[x]=dfn;
}
int mi(int x,int y)
{
if (dis[x]<dis[y]||!y) return x;
return y;
}
void make_ST()
{
fo(i,1,dfn) f[0][i]=ref[i];
fo(j,1,Log[dfn])
fo(i,1,dfn)
{
if (i+(1<<(j-1))>dfn) f[j][i]=f[j-1][i];else
f[j][i]=mi(f[j-1][i],f[j-1][i+(1<<(j-1))]);
}
}
int lca(int x,int y)
{
if (x==y) return x;
x=le[x];
y=le[y];
if (x>y) swap(x,y);
//if (ri[x]>ri[y]) return x;
int z=Log[y-x+1];
return mi(f[z][x],f[z][y-(1<<z)+1]);
}
int main()
{
freopen("a.in","r",stdin);
freopen("a.out","w",stdout);
scanf("%d %d %d",&n,&m,&q);
fo(i,1,n) scanf("%d",a+i);
lsh();//reverse
fo(i,1,n-1) scanf("%d %d",&x,&y),cr(x,y),cr(y,x);
dfs(1,0);
fo(i,1,n*2) Log[i]=trunc(log(i)/log(2));
make_ST();
fo(i,1,q)
{
scanf("%d %d %d",&x,&y,&z);
x^=lastans;
y^=lastans;
z^=lastans;
zz=disperse(z);
lc=lca(x,y);
c1=get(st[x],1,t1,1,zz-1)+get(st[y],1,t1,1,zz-1)-2*get(st[lc],1,t1,1,zz-1);
k1=get(st[x],1,t1,1,zz-1);
k2=get(st[y],1,t1,1,zz-1);
k3=get(st[lc],1,t1,1,zz-1);
if (rev[a[lc]]<z) c1++;
c2=0;
if (rev[zz]==z)
{
c2=get(st[x],1,t1,zz,zz)+get(st[y],1,t1,zz,zz)-2*get(st[lc],1,t1,zz,zz);
if (rev[a[lc]]==z) c2++;
}
c3=dis[x]+dis[y]-2*dis[lc]+1-c1-c2;
printf("%d %d %d\n",c1,c2,c3);
lastans=c1^c2^c3;
}
return 0;
}