思路:
树链剖分+树状数组
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N=100005;
int son[N];//son[i]表示i的重儿子
int size[N];//size[i]表示i为根包含自己的字数节点个数
int f[N];//f[i]表示i的父亲
int dep[N];//dep[i]表示i的深度
vector<int> G[N];
void addedge(int u,int v)
{
G[u].push_back(v);
G[v].push_back(u);
}
void dfs1(int cur,int fa)
{
size[cur]=1;
for(int i=0;i<G[cur].size();i++)
{
int to=G[cur][i];
if(to==fa) continue;
dep[to]=dep[cur]+1;
f[to]=cur;
dfs1(to,cur);
size[cur]+=size[to];
if(size[to]>size[son[cur]]) son[cur]=to;
}
}
int top[N];//top[i]表示节点i所在链的顶端
int id[N];//id[i]表示i的新编号
ll val[N];//val[i]表示新编号的点权
ll w[N];//w[i]表示原编号的点权
int cnt;
void dfs2(int cur,int t)
{
id[cur]=++cnt;
val[cnt]=w[cur];
top[cur]=t;
if(son[cur]) dfs2(son[cur],t);
for(int i=0;i<G[cur].size();i++)
if(G[cur][i]!=f[cur]&&G[cur][i]!=son[cur])
dfs2(G[cur][i],G[cur][i]);
}
ll a[N],sum1[N], sum2[N];
int n,m;
ll lowbit(ll x)
{
return x & (-x);
}
void updata(int i, ll k)
{
int x = i;
while (i <= n)
{
sum1[i] += k;
sum2[i] += k * (x - 1);
i += lowbit(i);
}
}
ll getsum(int i)
{
ll res = 0, x = i;
while (i > 0)
{
res += x * sum1[i] - sum2[i];
i -= lowbit(i);
}
return res;
}
void update_val(int x,int y,ll z)//修改点x到y的路径上点权
{
int fx = top[x], fy = top[y];
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
updata(id[fx], z);
updata(id[x]+1, -z);//每往上跳一次,就修改以次经过路径上的值,因为DFS序中id[fx] < id[x],所以是区间[id[fx],id[x]];
x = f[fx], fx = top[x];
}
if (id[x] > id[y]) swap(x, y);//要保证区间是从小到大的
updata(id[x], z);
updata(id[y]+1, -z);
}
ll query_val(int x, int y)
{
ll ans = 0;
int fx = top[x], fy = top[y];
while (fx != fy)
{
if (dep[fx] < dep[fy]) swap(x, y), swap(fx, fy);
ans+=getsum(id[x])-getsum(id[fx]-1);
x = f[fx], fx = top[x];
}
if (id[x] > id[y]) swap(x, y);
ans+=getsum(id[y])-getsum(id[x]-1);
return ans;
}
void update_son(int x,ll z)
{
updata(id[x],z);
updata(id[x]+size[x],-z);
}
void init()
{
cnt=0;
memset(son,0,sizeof(son));
memset(size,0,sizeof(size));
memset(f,0,sizeof(f));
memset(dep,0,sizeof(dep));
memset(a, 0, sizeof(a));
memset(sum1, 0, sizeof(sum1));
memset(sum2, 0, sizeof(sum2));
for(ll i=1;i<=n;i++) G[i].clear();
}
int main()
{
ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
cin>>n>>m;
init();
for (int i=1;i<=n;i++) cin >> w[i];
int u,v;
for(int i=0;i<n-1;i++)
{
cin>>u>>v;
addedge(u,v);
}
f[1]=1;
top[1]=1;
dfs1(1,1);
dfs2(1,1);
for(int i=1;i<=n;i++)
{
a[i]=val[i];
updata(i,a[i]-a[i-1]);
}
int flag;
while(m--)
{
int x,y;
cin>>flag;
if(flag==1)
{
cin>>x>>y;
update_val(x,x,y);
}
else if(flag==2)
{
cin>>x>>y;
update_son(x,y);
}
else
{
cin>>x;
cout<<query_val(1,x)<<endl;
}
}
return 0;
}