loj2564
SOL
有两个操作,1.树上区间数颜色 2.树上 n 2 n^2 n2版的操作1。
- 先放在序列上。
-
操作1:对于每一种颜色,只在其在区间的第一次出现位置统计。我们求出每一个位置对应的颜色的前一次出现位置(记为 p r e pre pre)(第一次就是 0 0 0),放在主席树中即可。
-
操作2,计算每一个点被多少区间访问到。
-对于 ( A , B ] (A,B] (A,B]之间的点 p p p,首先 p r e < A pre \lt A pre<A,左端点可以在 ( p r e , A ] (pre,A] (pre,A],右端点在 [ p , B ] [p,B] [p,B],总贡献: ( A − p r e ) ∗ ( B − p + 1 ) (A-pre)*(B-p+1) (A−pre)∗(B−p+1),由于对于 p r e pre pre有要求,主席树维护。-对于 [ 1 , A ] [1,A] [1,A]之前,左端点在 ( p r e , p ] (pre,p] (pre,p],右端点在 [ p , B ] [p,B] [p,B],或右端点在 ( p r e , p ] (pre,p] (pre,p],左端点在 [ p , A ] [p,A] [p,A],注意减去 ( p , p ) (p,p) (p,p)重复的一个贡献。为:
( p − p r e ) ∗ ( B − p + 1 ) + ( p − p r e ) ∗ ( A − p + 1 ) − 1 (p-pre)*(B-p+1)+(p-pre)*(A-p+1)-1 (p−pre)∗(B−p+1)+(p−pre)∗(A−p+1)−1可以发现对 p r e pre pre没有要求,前缀和维护即可。
- 转移到树上。
-
部分数据是随机造的,对于 p p p之前的是一条主链,长度不随机。但除开主链,副链从主链开始的深度期望 O ( l o g n ) O(logn) O(logn)。每种颜色的个数也是期望 O ( 1 ) O(1) O(1)
-
这提示我们可以将树上路径拆成长链和短链,短链直接暴力算,长链沿用序列的做法。
-
具体的,对于操作1,维护由点 i i i到根路径上每个颜色深度最大的点,记为 D c o l j , i D_{col_j,i} Dcolj,i。枚举短链上每个点尝试增加贡献,当且仅当 p r e i < L c a a n d D c o l i , 长 链 < L c a pre_i<Lca\ and\ D_{col_i,长链}<Lca prei<Lca and Dcoli,长链<Lca
-
对于操作2,先把 L c a Lca Lca到根重合的路径部分用序列的方法算出来,然后处理 ( L c a , A ] , ( L c a , B ] (Lca,A],(Lca,B] (Lca,A],(Lca,B]。为了不和第一种重复(lca本身不能取做端点了),我们先把 L c a Lca Lca拿出来,贡献单独算,为 ( A − L c a ) ∗ ( B − L c a ) (A-Lca)*(B-Lca) (A−Lca)∗(B−Lca)。
-对于长链上的点,满足 p r e i < L c a pre_i<Lca prei<Lca就可以产生 ( A − i + 1 ) ∗ ( B − L c a ) (A-i+1)*(B-Lca) (A−i+1)∗(B−Lca)的贡献,
对于短链,不紧在短链上的 p r e < L c a pre\lt Lca pre<Lca,在如果 D c o l i , 长 链 < L c a D_{col_i,长链}<Lca Dcoli,长链<Lca,贡献为 ( A − L c a ) ∗ ( B − i + 1 ) (A-Lca)*(B-i+1) (A−Lca)∗(B−i+1),否则找到 ( L c a , A ] (Lca,A] (Lca,A]中深度最浅的 c o l i col_i coli的点,贡献为 ( L o w e s t c o l i , 长 链 − L c a − 1 ) ∗ ( B − i + 1 ) (Lowest_{col{i,长链}}-Lca-1)*(B-i+1) (Lowestcoli,长链−Lca−1)∗(B−i+1)。怎么找?因为颜色出现数期望 O ( 1 ) O(1) O(1),暴力跳 p r e [ ] pre[\ ] pre[ ]即可。
CODE
#include<bits/stdc++.h>
#define pf printf
#define sf scanf
#define cs const
#define ll long long
#define db double
#define ri register int
using namespace std;
#define in red()
inline int red()
{
int data=0;int w=1; char ch=0;
ch=getchar();
while(ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if(ch=='-') w=-1,ch=getchar();
while(ch>='0' && ch<='9') data=(data<<3)+(data<<1)+ch-'0',ch=getchar();
return data*w;
}
cs int N=2e5+10;
struct node{
ll s1,s2,s3,s4;
}tt;
inline node operator +(node a,node b){
a.s1+=b.s1;a.s2+=b.s2;a.s3+=b.s3;a.s4+=b.s4;
return a;
}
inline node operator -(node a,node b){
a.s1-=b.s1;a.s2-=b.s2;a.s3-=b.s3;a.s4-=b.s4;
return a;
}
//preμ??÷?ˉê÷
namespace P1{
node val[N*20];
int ch[N*20][2],tot=0,rt1[N];
#define lc(x) ch[x][0]
#define rc(x) ch[x][1]
inline int cpy(int x){++tot;val[tot]=val[x];ch[tot][0]=ch[x][0];ch[tot][1]=ch[x][1];return tot;}
inline void upt(int &p,int t,int l,int r,int k){
p=cpy(t);val[p]=val[p]+tt;
if(l==r)return;
int mid=(l+r)>>1;
if(k<=mid)upt(lc(p),lc(t),l,mid,k);
else upt(rc(p),rc(t),mid+1,r,k);
}
inline node qy(int p,int s,int l,int r,int k){
if(r<=k)return val[s]-val[p];
int mid=(l+r)>>1;
if(k<=mid)return qy(lc(p),lc(s),l,mid,k);
else return qy(lc(p),lc(s),l,mid,k)+qy(rc(p),rc(s),mid+1,r,k);
}
}
using P1::rt1;
//??é??a??±êμ??÷?ˉê÷
namespace P2{
int num[N*20],ch[N*20][2],tot=0,rt2[N];
#define lc(x) ch[x][0]
#define rc(x) ch[x][1]
inline int cpy(int x){++tot;ch[tot][0]=ch[x][0];ch[tot][1]=ch[x][1];return tot;}
inline void upt(int &p,int t,int l,int r,int k,int v){
p=cpy(t);
if(l==r)return num[p]=v,void();
int mid=(l+r)>>1;
if(k<=mid)upt(lc(p),lc(t),l,mid,k,v);
else upt(rc(p),rc(t),mid+1,r,k,v);
}
inline int qy(int p,int l,int r,int k){
if(l==r)return num[p];
int mid=(l+r)>>1;
if(k<=mid)return qy(lc(p),l,mid,k);
else return qy(rc(p),mid+1,r,k);
}
}
using P2::rt2;
typedef pair<ll,ll> pi;
#define fi first
#define se second
pi sum[N];
int las[N],pre[N],dep[N],col[N],top,fa[N],n,m;
int st[20][N<<1],Log[N<<1],dfn[N],tim=0;
inline int _min(int a,int b){return dfn[a]<dfn[b] ? a : b;}
inline int lca(int a,int b){
int fx=min(dfn[a],dfn[b]),fy=max(dfn[a],dfn[b]),k=Log[fy-fx+1];
return _min(st[k][fx],st[k][fy-(1<<k)+1]);
}
inline void init(){
for(ri i=2;i<=tim;++i)Log[i]=Log[i>>1]+1;
for(ri i=1;i<=Log[tim];++i){
for(ri j=1;j+(1<<i)-1<=tim;++j){
st[i][j]=_min(st[i-1][j],st[i-1][j+(1<<i-1)]);
}
}
}
vector<int> g[N];
inline void addedge(int u,int v){g[u].push_back(v);}
void dfs(int u){
dfn[u]=++tim;st[0][tim]=u;
pre[u]=las[col[u]];dep[u]=dep[fa[u]]+1;
las[col[u]]=u;
tt.s1=dep[u];tt.s2=dep[pre[u]];tt.s3=tt.s1*tt.s2;tt.s4=1;
P1::upt(rt1[u],rt1[fa[u]],0,n,dep[pre[u]]);
P2::upt(rt2[u],rt2[fa[u]],1,n,col[u],u);
sum[u].fi=sum[fa[u]].fi+dep[u]-dep[pre[u]];sum[u].se=sum[fa[u]].se+1ll*(dep[u]-dep[pre[u]])*(-2*dep[u]+2)-1ll;
for(ri i=g[u].size()-1;i>=0;--i){
int v=g[u][i];
fa[v]=u;
dfs(v);
st[0][++tim]=u;
}
las[col[u]]=pre[u];
}
inline int query1(int x,int y){
if(dep[x]<dep[y])swap(x,y);
int L=lca(x,y);
node tmp=P1::qy(rt1[fa[L]],rt1[x],0,n,dep[L]-1);
int res=tmp.s4;
while(y^L){
if(dep[pre[y]]<dep[L]&&dep[P2::qy(rt2[x],1,n,col[y])]<dep[L])++res;
y=fa[y];
}
return res;
}
inline ll Line(int A,int L){
ll res=0;
if(A^L){
node tmp=P1::qy(rt1[L],rt1[A],0,n,dep[L]-1);
res+=tmp.s4*(dep[A]+1)*dep[L]+tmp.s3-tmp.s2*(dep[A]+1)-tmp.s1*dep[L];
}
res+=sum[L].fi*(dep[A]+dep[L])+sum[L].se;
return res;
}
inline ll query2(int x,int y){
if(dep[x]<dep[y])swap(x,y);
int L=lca(x,y);
ll res=Line(x,L)+Line(y,L)-Line(L,L);
if(y==L)return res;
node tmp=P1::qy(rt1[L],rt1[x],0,n,dep[L]-1);
res+=tmp.s1*(dep[L]-dep[y])+tmp.s4*(dep[x]+1)*(dep[y]-dep[L])+1ll*(dep[x]-dep[L])*(dep[y]-dep[L]);
int B=dep[y];
while(y^L){
if(dep[pre[y]]<dep[L]){
int i=P2::qy(rt2[x],1,n,col[y]);
if(dep[i]<dep[L])res+=1ll*(dep[x]-dep[L])*(B-dep[y]+1);
else{
for(;dep[pre[i]]>=dep[L];i=pre[i]);
if(i^L)res+=1ll*(dep[i]-dep[L]-1)*(B-dep[y]+1);
}
}
y=fa[y];
}
return res;
}
unsigned int SA, SB, SC;
unsigned int rng61(){
SA ^= SA << 16;
SA ^= SA >> 5;
SA ^= SA << 1;
unsigned int t = SA;
SA = SB;
SB = SC;
SC ^= t ^ SA;
return SC;
}
void gen(){
int p;
scanf("%d%d%u%u%u", &n, &p, &SA, &SB, &SC);
for(int i = 2; i <= p; i++)
addedge(i - 1, i);
for(int i = p + 1; i <= n; i++)
addedge(rng61() % (i - 1) + 1, i);
for(int i = 1; i <= n; i++)
col[i] = rng61() % n + 1;
}
signed main (){
// freopen("data.in","r",stdin);
int T=in;
while(T--){
gen();
dfs(1);init();
m=in;
while(m--){
int op=in,x=in,y=in;
if(op==1)cout<<query1(x,y)<<'\n';
else cout<<query2(x,y)<<'\n';
}
for(ri i=1;i<=n;++i)g[i].clear();
P1::tot=P2::tot=0;tim=0;
}
return 0;
}