【题目大意】
n个节点的一棵树,有三种操作。
1:将x到y的路径上的所有点的点权加上delta
2:询问x到y的答案。答案的计算为:对于路径上的点i,设它到y的距离为s,则i的贡献为1加到s。
3:将这棵树恢复到第x次1操作之后的版本。
操作数为m,强制在线。
【20%】n,m<=1000
对于1、2操作,暴力遍历x到y的路径去修改或求答案;并且我们存下每一次1操作后的版本,然后对于3操作,O(n)修改整棵树。
时间复杂度O(nm)
【20% 树的形态为一条链,无操作3】n<=30000, m<=50000
考虑询问。我们将x到y的路径分为x到lca和lca到y两部分。
对于第一部分的点i,设它到y的距离为
s
,则
对于第二部分的点i,
s=deep[y]−deep[i]
,设
t=deep[y]
,则贡献为
a[i]∗deep[i]2−a[i]∗deep[i]∗(2∗t+1)+a[i]∗(t+t2)
。
可以发现,对于每个点i我们只需维护
a[i]∗deep[i]2
、
a[i]∗deep[i]
和
a[i]
即可。这个就是线段树的基本功能了。(提醒:上述式子记得加上/2)
时间复杂度O(m log n)
【40% 树的形态为一条链,有操作3】
操作3显然就是要让我们把普通线段树加上可持久化。
时间复杂度O(m logn)
可以感觉到代码量已经膨胀了。
【100%】n,m<=10^5,操作2<=50000
序列放在树上,通过树链剖分,转化为上述问题。
时间复杂度O(m log^2)
#include<cstdio>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
using namespace std;
typedef long long LL;
const int maxn=(1e5)+5, MX=18, maxtr=7000004;
const LL mo=20160501, er=10080251;
struct TRTree{
LL d2,d1,ad2,ad1,a,bz,nowdelta;
};
int n,m;
LL a[maxn],ans,ansd1[2],ansa[2],delta;
int tot,go[2*maxn],next[2*maxn],f1[maxn];
void ins(int x,int y)
{
go[++tot]=y;
next[tot]=f1[x];
f1[x]=tot;
}
int d[maxn],fa[maxn],deep[maxn],size[maxn],Hson[maxn];
void bfs_size()
{
d[1]=1;
deep[1]=1;
for(int i=1, j=1; i<=j; i++)
{
for(int p=f1[d[i]]; p; p=next[p]) if (!deep[go[p]])
{
d[++j]=go[p];
deep[go[p]]=deep[d[i]]+1;
fa[go[p]]=d[i];
}
}
fd(i,n,1)
{
size[fa[d[i]]]+= (++size[d[i]]);
if (size[d[i]] > size[Hson[ fa[d[i]] ]]) Hson[fa[d[i]]]=d[i];
}
}
int sum,Tbh[maxn],Lbh[maxn],top[maxn];
void New(int i,int sta)
{
Lbh[i]=++sum;
Tbh[sum]=i;
top[i]=sta;
}
void build()
{
fo(i,1,n) if (!Lbh[d[i]])
{
for(int j=d[i]; j; j=Hson[j]) New(j,d[i]);
}
}
int now,root[maxn],trsum,lasttrsum,son[maxtr][2],tm;
TRTree tr[maxtr];
void tr_js(int k,int l,int r)
{
if (l==r)
{
int x=Tbh[l];
tr[k].d1=(LL)deep[x];
tr[k].d2=(LL)deep[x]*deep[x];
tr[k].a=a[x];
tr[k].ad1=a[x]*tr[k].d1;
tr[k].ad2=a[x]*tr[k].d2;
return;
}
int t1=(l+r)>>1;
tr_js(son[k][0]=++trsum,l,t1), tr_js(son[k][1]=++trsum,t1+1,r);
int ls=son[k][0], rs=son[k][1];
tr[k].ad2=(tr[ls].ad2+tr[rs].ad2)%mo;
tr[k].ad1=(tr[ls].ad1+tr[rs].ad1)%mo;
tr[k].a=(tr[ls].a+tr[rs].a)%mo;
tr[k].d2=(tr[ls].d2+tr[rs].d2)%mo;
tr[k].d1=(tr[ls].d1+tr[rs].d1)%mo;
}
void tr_xg(int k,int last,int l,int r,int x,int y)
{
if (k==trsum) tr[k]=tr[last];
if (l==x && r==y)
{
son[k][0]=son[last][0], son[k][1]=son[last][1];
tr[k].nowdelta=(tr[k].nowdelta+delta)%mo;
tr[k].ad2=(tr[k].ad2+delta*tr[k].d2)%mo;
tr[k].ad1=(tr[k].ad1+delta*tr[k].d1)%mo;
tr[k].a=(tr[k].a+delta*(r-l+1))%mo;
tr[k].bz=(tr[k].bz+delta)%mo;
return;
}
int t1=(l+r)>>1;
if (y<=t1)
{
if (son[k][1]<=lasttrsum) son[k][1]=son[last][1];
if (son[k][0]<=lasttrsum) son[k][0]=++trsum;
tr_xg(son[k][0],son[last][0],l,t1,x,y);
} else if (x>t1)
{
if (son[k][0]<=lasttrsum) son[k][0]=son[last][0];
if (son[k][1]<=lasttrsum) son[k][1]=++trsum;
tr_xg(son[k][1],son[last][1],t1+1,r,x,y);
} else
{
if (son[k][0]<=lasttrsum) son[k][0]=++trsum;
tr_xg(son[k][0],son[last][0],l,t1,x,t1);
if (son[k][1]<=lasttrsum) son[k][1]=++trsum;
tr_xg(son[k][1],son[last][1],t1+1,r,t1+1,y);
}
int ls=son[k][0], rs=son[k][1]; LL BZ=tr[k].bz;
tr[k].ad2=(tr[ls].ad2+tr[rs].ad2+ BZ*tr[k].d2%mo )%mo;
tr[k].ad1=(tr[ls].ad1+tr[rs].ad1+ BZ*tr[k].d1%mo )%mo;
tr[k].a=(tr[ls].a+tr[rs].a+ BZ*(r-l+1)%mo )%mo;
}
void upans(int k,int ty,LL z,LL len)
{
ans=(ans+tr[k].ad2+ z*tr[k].d2%mo )%mo;
ansd1[ty]=(ansd1[ty]+tr[k].ad1+ z*tr[k].d1%mo )%mo;
ansa[ty]=(ansa[ty]+tr[k].a+ z*len%mo )%mo;
}
void tr_cx(int k,int l,int r,int x,int y,bool ty,LL tbz)
{
if (!k) return;
tbz+=tr[k].bz;
if (l==x && r==y)
{
tbz-=tr[k].nowdelta;
upans(k,ty,tbz,r-l+1);
return;
}
int t1=(l+r)>>1;
if (y<=t1) tr_cx(son[k][0],l,t1,x,y,ty,tbz);
else if (x>t1) tr_cx(son[k][1],t1+1,r,x,y,ty,tbz);
else tr_cx(son[k][0],l,t1,x,t1,ty,tbz), tr_cx(son[k][1],t1+1,r,t1+1,y,ty,tbz);
}
void jump_xg(int x,int y)
{
while (top[x]!=top[y])
{
if (deep[top[x]]<deep[top[y]]) swap(x,y);
tr_xg(root[tm],root[now],1,n,Lbh[top[x]],Lbh[x]);
x=fa[top[x]];
}
if (deep[x]<deep[y]) swap(x,y);
tr_xg(root[tm],root[now],1,n,Lbh[y],Lbh[x]);
}
int jump_cx(int x,int y)
{
while (top[x]!=top[y])
{
if (deep[top[x]]>deep[top[y]])
{
tr_cx(root[now],1,n,Lbh[top[x]],Lbh[x],0,0);
x=fa[top[x]];
} else
{
tr_cx(root[now],1,n,Lbh[top[y]],Lbh[y],1,0);
y=fa[top[y]];
}
}
if (deep[x]>deep[y])
{
tr_cx(root[now],1,n,Lbh[y],Lbh[x],0,0);
return y;
} else
{
tr_cx(root[now],1,n,Lbh[x],Lbh[y],1,0);
return x;
}
}
int main()
{
scanf("%d %d",&n,&m);
fo(i,1,n-1)
{
int u,v;
scanf("%d %d",&u,&v);
ins(u,v), ins(v,u);
}
bfs_size();
build();
fo(i,1,n) scanf("%lld",&a[i]);
root[0]=trsum=1, tr_js(1,1,n);
while (m--)
{
int ty,x,y;
scanf("%d %d",&ty,&x);
if (ty==1)
{
scanf("%d %lld",&y,&delta);
x^=ans, y^=ans;
lasttrsum=trsum;
root[++tm]=++trsum;
jump_xg(x,y);
now=tm;
} else if (ty==2)
{
scanf("%d",&y);
x^=ans, y^=ans;
ans=0;
ansd1[0]=ansd1[1]=ansa[0]=ansa[1]=0;
int lca=jump_cx(x,y);
LL t=deep[y]-2*deep[lca];
ans=(ans+ ansd1[0]*(2*t+1)%mo +ansa[0]*(t+t*t%mo)%mo) %mo;
t=deep[y];
ans=(ans- ansd1[1]*(2*t+1)%mo +ansa[1]*(t+t*t%mo)%mo) %mo;
ans=(ans+mo)%mo;
ans=ans*er%mo;
printf("%lld\n",ans);
} else
{
x^=ans;
now=x;
}
}
}