#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
// mo: modulus applied to all sums; va[v]: value of vertex v (1-based);
// nva[i]: value of the vertex whose dfn is i (filled by dfs2, read by build).
// Arrays are sized for at most 100000 vertices.
ll mo,va[100001],nva[100001];
// n vertices, m operations, root r.
// Per-vertex HLD data: fa = parent (0 for the root), top = top vertex of the
// heavy chain containing the vertex, sz = subtree size, dfn = DFS order index
// (timer is the running counter), depth = depth (root has depth 1),
// hson = heavy child (0 if the vertex has no children).
int n,m,r,fa[100001],top[100001],sz[100001],timer,dfn[100001],depth[100001],hson[100001];
// e[v]: adjacency list of vertex v (each input edge is stored in both directions).
vector <int> e[100001];
// Segment tree over dfn positions 1..n (1-based heap layout, ~4x capacity):
// sum[p] = sum of node p's interval (mod mo), tag[p] = pending lazy add.
// NOTE(review): tre[] is not referenced anywhere in this file.
ll sum[400001],tre[400001],tag[400001];
// Segment tree: range add / range sum modulo mo, with lazy tags.
// Index of the left child of segment-tree node p (1-based heap layout).
int ls(int p)
{
    const int left = p * 2;
    return left;
}
// Index of the right child of segment-tree node p (1-based heap layout).
int rs(int p)
{
    const int right = p * 2 + 1;
    return right;
}
void f(int p,int l,int r,ll k)
{
tag[p] = ( tag[p] + k ) % mo;
sum[p] = (sum[p] + k * (r - l + 1)) % mo;
}
void pushup(int p){sum[p] = (sum[ls(p)] + sum[rs(p)]) % mo;}
// Push node p's pending lazy add down to its two children; p covers [l, r].
// When tag[p] is 0 nothing is pending, so return early: applying a zero add
// would leave both children unchanged anyway.
void pushdown(int p,int l,int r)
{
    if(tag[p] == 0)return ;
    const int mid = ( (l + r) >> 1 );
    f(ls(p),l,mid,tag[p]);
    f(rs(p),mid + 1,r,tag[p]);
    tag[p] = 0;
}
void add(int gl,int gr,int nl,int nr,int p,ll k)
{
if(nl >= gl && nr <= gr)
{
sum[p] = (sum[p] + k * (nr - nl + 1)) % mo;
tag[p] = (tag[p] + k) % mo;
return ;
}
pushdown(p,nl,nr);
int mid = ((nl + nr) >> 1);
if(mid >= gl)add(gl,gr,nl,mid,ls(p),k);
if(mid < gr)add(gl,gr,mid + 1,nr,rs(p),k);
pushup(p);
}
// Range sum over [gl, gr], modulo mo. Node p covers [nl, nr];
// call with (gl, gr, 1, n, 1) from the root.
ll query(int gl,int gr,int nl,int nr,int p)
{
    // Node lies entirely inside the query range: its stored sum is the answer.
    if(gl <= nl && gr >= nr)return sum[p];
    pushdown(p,nl,nr);
    const int mid = ((nl + nr) >> 1);
    ll leftPart = 0, rightPart = 0;
    if(gl <= mid)leftPart = query(gl,gr,nl,mid,ls(p));
    if(gr > mid)rightPart = query(gl,gr,mid + 1,nr,rs(p));
    return (leftPart + rightPart) % mo;
}
// Build the segment tree node p over positions [l, r] from nva[],
// whose values are already laid out in DFS (dfn) order.
void build(int p,int l,int r)
{
    if(l != r)
    {
        const int mid = ( (l + r) >> 1 );
        build(ls(p),l,mid);
        build(rs(p),mid + 1,r);
        pushup(p);
        return ;
    }
    // Leaf: a single position.
    sum[p] = nva[l] % mo;
}
// End of segment tree; heavy-light decomposition follows.
// First DFS: records depth[x] = dep, and for every vertex below x its parent
// (fa), subtree size (sz) and heavy child (hson = child with the largest
// subtree, 0 for a leaf). fa[x] must already hold x's parent; for the root it
// is the zero-initialised global value 0, which never matches a real vertex.
// NOTE(review): recursion depth equals the tree height (up to n), so a
// path-shaped tree needs a correspondingly deep call stack.
void dfs1(int x,int dep)
{
    depth[x] = dep;
    sz[x] = 1;
    int maxx = 0,maxson = 0;
    for (int v : e[x])
    {
        if(v == fa[x])continue; // skip the edge back to the parent
        fa[v] = x;
        dfs1(v,dep + 1);
        sz[x] += sz[v];
        // Strict '>' keeps the first-seen child among equal-size subtrees.
        if(sz[v] > maxx)
        {
            maxx = sz[v];
            maxson = v;
        }
    }
    hson[x] = maxson;
}
// Second DFS: assigns dfn (pre-order index) so that every heavy chain and
// every subtree occupies a contiguous dfn range, copies va[x] into nva at
// position dfn[x], and records top[x] = t, the top of x's heavy chain.
// The heavy child is visited first so it continues chain t; every other
// (light) child starts a new chain headed by itself.
void dfs2(int x,int t)
{
    dfn[x] = ++timer;
    nva[timer] = va[x];
    top[x] = t;
    if(!hson[x])return ; // no heavy child means x has no children
    dfs2(hson[x],t);
    for (int v : e[x])
    {
        if(v == hson[x] || v == fa[x])continue;
        dfs2(v,v);
    }
}
// Add z (mod mo) to every vertex on the tree path between x and y.
// While the endpoints lie on different heavy chains, take the endpoint whose
// chain top is deeper, update that chain's dfn segment [dfn[top], dfn[x]],
// and move it to the parent of its chain top.
void addnum(int x,int y,ll z)
{
    z %= mo;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        // Make x the endpoint whose chain top is deeper.
        if(depth[top[y]] > depth[top[x]])swap(x,y);
        add(dfn[top[x]],dfn[x],1,n,1,z);
    }
    // Same chain now: make y the shallower vertex so dfn[y] <= dfn[x].
    if(depth[y] > depth[x])swap(x,y);
    add(dfn[y],dfn[x],1,n,1,z);
}
// Add z (mod mo) to every vertex in x's subtree, which occupies the
// contiguous dfn range [dfn[x], dfn[x] + sz[x] - 1].
void addtre(int x,ll z)
{
    const int first = dfn[x];
    const int last = dfn[x] + sz[x] - 1;
    add(first,last,1,n,1,z % mo);
}
// Sum (mod mo) of the values in x's subtree, i.e. dfn range
// [dfn[x], dfn[x] + sz[x] - 1].
ll quetre(int x)
{
    const int first = dfn[x];
    const int last = first + sz[x] - 1;
    return query(first,last,1,n,1);
}
// Sum (mod mo) of the values on the tree path between x and y.
// While the endpoints lie on different heavy chains, the endpoint with the
// deeper chain top contributes its chain segment and then jumps to the parent
// of that chain top; finally the shared-chain segment is added.
ll quesum(int x,int y)
{
    ll total = 0;
    for (; top[x] != top[y]; x = fa[top[x]])
    {
        if(depth[top[y]] > depth[top[x]])swap(x,y); // x: deeper chain top
        total = (total + query(dfn[top[x]],dfn[x],1,n,1)) % mo;
    }
    if(depth[y] > depth[x])swap(x,y); // y is the shallower endpoint
    total = (total + query(dfn[y],dfn[x],1,n,1)) % mo;
    return total;
}
// Entry point. Input: n m r mo, then the n vertex values, then n-1 undirected
// edges "a b", then m operations, each starting with an op code:
//   1 x y z : add z to every vertex on the path x-y
//   2 x y   : print the path sum x-y (mod mo)
//   3 x z   : add z to every vertex in the subtree of x
//   4 x     : print the subtree sum of x (mod mo)
// Any op code other than 1-3 is handled as op 4.
int main()
{
    scanf(" %d %d %d %lld",&n,&m,&r,&mo);
    for (int i = 1; i <= n; i++)
    {
        scanf(" %lld",&va[i]);
    }
    int a,b;
    for (int i = 1; i < n; i++)
    {
        // Read edges with scanf like every other input here rather than
        // mixing in std::cin, which is slow while synchronised with stdio.
        scanf(" %d %d",&a,&b);
        e[a].push_back(b),e[b].push_back(a);
    }
    dfs1(r,1);    // depth, parent, subtree size, heavy child
    dfs2(r,r);    // dfn order, chain tops, nva[] in dfn order
    build(1,1,n);
    while(m--)
    {
        int c;
        ll d;
        scanf(" %d",&a);
        if(a == 1)
        {
            scanf(" %d %d %lld",&b,&c,&d);
            addnum(b,c,d);
        }
        else if(a == 2)
        {
            scanf(" %d %d",&b,&c);
            printf("%lld\n",quesum(b,c));
        }
        else if(a == 3)
        {
            scanf(" %d %lld",&b,&d);
            addtre(b,d);
        }
        else
        {
            scanf(" %d",&b);
            printf("%lld\n",quetre(b));
        }
    }
    return 0;
}
/*
简单来说:
先从根节点跑一遍dfs,求出dfs序
分出轻重链和轻重边,树即可转化为序列
可以证明任意一条路径上最多包含O(logn)条链(反证)
主要说一下路径上的修改/求和:
标记top,不断往上跳,注意要让top深度更大的那个点跳;while结束后两点已在同一条链上,还要再处理这最后一段,跳的同时处理
犯的错:
把r和l搞反了(典型的左右不分)
dfn[]和fa[]有些地方忘加了
query()写成了(),漏掉了函数名,我也是服了,隔一段时间就会犯这样的错,得出的答案还很诡异,真的很难看出是哪错了。。。
*/