题目
https://www.luogu.org/problem/show?pid=3384
题解
辣鸡题目毁我青春!
以前写线段树时，指针从来不赋初值，偏偏这道题较真：本来应该 1A 的题目，我提交了 29 遍！我的 AC 率，我的时间！（教训：new 出来的结点必须把 lch/rch 显式置空，下面 segtree 的构造函数就是干这个的。）
呃，吐槽完了。说到底还是自己的习惯不好。
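轻重边剖分之后，每个点的 tid 其实就是它的 dfs 序（重儿子优先访问）；把节点按 tid 从小到大排列，得到的就是树的先序遍历（允许我乱用概念吧）。因此，以某个节点 x 为根的子树，恰好是时间戳落在 [l[x], r[x]] 内的那些节点：l[x] 是刚进入 x 这层 dfs 时的时间戳，r[x] 是退出这层 dfs 时的时间戳。这样一来，子树修改和子树查询就直接变成了线段树上对区间 [l[x], r[x]] 的操作。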
代码
//树链剖分+dfs序
#include <cstdio>
#include <algorithm>
#define maxn 200010
using namespace std;
// Problem parameters read by init(): node count N, operation count M,
// root R and modulus P.
int N, M, R, P;
// l[x] / r[x]: first and last dfs timestamps inside x's subtree; tim: timestamp counter.
int l[maxn], r[maxn], tim;
// Heavy-light data: parent, subtree size, heavy child and dfs timestamp (tid),
// which is also the node's position in w and in the segment tree.
// NOTE: with `using namespace std;` an unqualified `size` / `next` can be
// ambiguous with std::size / std::next (C++11 and later); write ::size / ::next.
int fa[maxn], size[maxn], son[maxn], tid[maxn];
// Adjacency lists: head[u] = first edge id of u, next[e] = following edge id, to[e] = endpoint.
int head[maxn], next[maxn], to[maxn];
// val[x]: value of node x as read; w[t]: value of the node whose timestamp is t.
int val[maxn], w[maxn];
// tot: number of edges added so far; top[x]: head of x's heavy chain; deep[x]: depth of x.
int tot, top[maxn], deep[maxn];
// NOTE(review): not referenced anywhere in the visible code.
int debug;
// Node of a pointer-based segment tree with lazy range addition.
struct segtree
{
    // l, r: the closed index range this node covers.
    // sum:  sum of the range (pushdown()/update() keep it reduced mod P).
    // d:    pending addend for every position in the range, not yet folded into sum.
    int l, r, sum, d;
    // Children; both stay null for a leaf. They must start null because
    // pushdown() and update() test lch to detect leaves.
    segtree *lch, *rch;
    segtree() : l(0), r(0), sum(0), d(0), lch(0), rch(0) {}
} *root;
void adde(int a, int b){to[++tot]=b;next[tot]=head[a];head[a]=tot;}
// Folds node p's pending lazy addend d into p->sum (d per covered position),
// passes d on to both children, and clears it. build() always creates lch and
// rch together, so testing lch alone is enough to tell leaves apart.
// All arithmetic is done in long long and reduced mod P. With P close to 2^31,
// the former int expressions (sum + len*d, child d += d) could overflow.
// Reducing the tags mod P leaves every answer mod P unchanged.
void pushdown(segtree *p)
{
    long long len = p->r - p->l + 1;
    p->sum = (int)((p->sum + len * p->d) % P);
    if (p->lch)
    {
        p->lch->d = (int)(((long long)p->lch->d + p->d) % P);
        p->rch->d = (int)(((long long)p->rch->d + p->d) % P);
    }
    p->d = 0;
}
// Recomputes an inner node's sum from its children (leaves are left alone).
// Both children are pushed down first so that their sums include any pending
// lazy addend. The addition is done in long long because two values below P
// can exceed INT_MAX when P is close to 2^31.
void update(segtree *p)
{
    if (p->lch == 0) return;
    pushdown(p->lch);
    pushdown(p->rch);
    p->sum = (int)(((long long)p->lch->sum + p->rch->sum) % P);
}
// Adds d to every position of [l, r] that lies inside p's range.
// A fully covered node only records d as its lazy tag. The next pushdown() on
// that node folds the tag in; for a non-root node, the parent's update() does
// this right away on the way back up.
void segadd(segtree *p, int l, int r, int d)
{
    pushdown(p);
    bool covered = l <= p->l && p->r <= r;
    if (covered)
    {
        p->d += d;
        return;
    }
    int half = (p->l + p->r) >> 1;
    if (half >= l)
        segadd(p->lch, l, r, d);
    if (half < r)
        segadd(p->rch, l, r, d);
    update(p);
}
// Returns the sum, mod P, of the positions in [l, r] that lie inside p's range.
// Each node is pushed down before it is read, so pending addends are counted.
// Child results are each below P, but their sum can exceed INT_MAX when P is
// close to 2^31, so the accumulator is long long.
int segsum(segtree *p, int l, int r)
{
    pushdown(p);
    if (l <= p->l && r >= p->r) return p->sum;
    int mid = (p->l + p->r) >> 1;
    long long ans = 0;
    if (l <= mid) ans += segsum(p->lch, l, r);
    if (r > mid) ans += segsum(p->rch, l, r);
    return (int)(ans % P);
}
// Builds the subtree rooted at p over positions [l, r] of w. Here w is indexed
// by dfs timestamp (init() stores val[i] at w[tid[i]]). A leaf takes w[l]
// directly. An inner node allocates its left child, builds it, then does the
// same for its right child, and finally combines the two via update().
void build(segtree *p, int l, int r)
{
    p->l = l;
    p->r = r;
    if (l == r)
    {
        p->sum = w[l];
        return;
    }
    int mid = (l + r) >> 1;
    p->lch = new segtree;
    build(p->lch, l, mid);
    p->rch = new segtree;
    build(p->rch, mid + 1, r);
    update(p);
}
void dfs1(int pos)
{
int p;
size[pos]=1;
for(p=head[pos];p;p=next[p])
{
if(to[p]==fa[pos])continue;
fa[to[p]]=pos;
deep[to[p]]=deep[pos]+1;
dfs1(to[p]);
if(size[to[p]]>size[son[pos]])son[pos]=to[p];
size[pos]+=size[to[p]];
}
}
void dfs2(int pos, int tp)
{
int p;
top[pos]=tp;
tid[pos]=++tim;
l[pos]=tim;
if(son[pos])dfs2(son[pos],tp);
for(p=head[pos];p;p=next[p])
{
if(to[p]==fa[pos] or to[p]==son[pos])continue;
dfs2(to[p],to[p]);
}
r[pos]=tim;
}
// Reads N, M, R, P, then the N node values and the N-1 undirected edges from
// stdin. It then decomposes the tree rooted at R, lays the values out by
// timestamp in w, and builds the segment tree over positions [1, tim].
void init()
{
    scanf("%d%d%d%d", &N, &M, &R, &P);
    for (int v = 1; v <= N; v++)
        scanf("%d", &val[v]);
    for (int k = 1; k < N; k++)
    {
        int a, b;
        scanf("%d%d", &a, &b);
        adde(a, b);
        adde(b, a);
    }
    dfs1(R);
    dfs2(R, R);
    for (int v = 1; v <= N; v++)
        w[tid[v]] = val[v];
    root = new segtree;
    build(root, 1, tim);
}
void add(int a, int b, int d)
{
int ta=top[a], tb=top[b];
while(ta!=tb)
{
if(deep[ta]<deep[tb])swap(a,b),swap(ta,tb);
segadd(root,tid[ta],tid[a],d);
a=fa[ta];ta=top[a];
}
if(deep[a]>deep[b])swap(a,b);
segadd(root,tid[a],tid[b],d);
}
int sum(int a, int b)
{
int ta=top[a], tb=top[b], ans=0;
while(ta!=tb)
{
if(deep[ta]<deep[tb])swap(a,b),swap(ta,tb);
ans+=segsum(root,tid[ta],tid[a]);ans%=P;
a=fa[ta];ta=top[a];
}
if(deep[a]>deep[b])swap(a,b);
ans+=segsum(root,tid[a],tid[b]);ans%=P;
return ans;
}
// Reads the tree via init(), then runs M operations read from stdin:
//   1 x y z : add z to every node on the path x..y
//   2 x y   : print the path sum x..y (mod P)
//   3 x z   : add z to every node in x's subtree, i.e. timestamps [l[x], r[x]]
//   4 x     : print the sum of x's subtree (mod P)
int main()
{
    // type starts at 0 so that a failed first read does nothing instead of
    // using an indeterminate value.
    int type = 0, x, y, z;
    init();
    for (int i = 1; i <= M; i++)
    {
        scanf("%d", &type);
        if (type == 1)
        {
            scanf("%d%d%d", &x, &y, &z);
            add(x, y, z);
        }
        else if (type == 2)
        {
            scanf("%d%d", &x, &y);
            printf("%d\n", sum(x, y));
        }
        else if (type == 3)
        {
            scanf("%d%d", &x, &z);
            segadd(root, l[x], r[x], z);
        }
        else if (type == 4)
        {
            scanf("%d", &x);
            printf("%d\n", segsum(root, l[x], r[x]));
        }
    }
    return 0;
}