wls 有三棵树,树上每个节点都有一个值 ai,现在有 2 种操作:
- 将一条链上的所有节点的值开根号向下取整;
- 求一条链上值的和;
链的定义是两点之间的最短路。
Input
第一行两个数 n, q 分别代表树上点的数量和操作数量。
第二行 n 个整数,第 i 个数代表第 i 个点的值 ai。
接下来 n − 1 行, 每行两个整数 u, v 代表 u,v 之间有一条边。数据保证点两两联通。
接下来 q 行,每行有个整数 op, u, v,op = 0 表示将 u, v 这条链上所有的点的值开根号向下取整,op = 1表示询问 u,v 这条链上的值的和。
1 ≤ n, q ≤ 100, 000
0 ≤ ai ≤ 1, 000, 000, 000
Output
对于每一组 op = 2 的询问,输出一行一个值表示答案。
Sample Input
4 4
2 3 4 5
1 2
2 3
2 4
0 3 4
0 1 3
1 2 3
1 1 4
Sample Output
2
4
区间求和,区间开根号暴力修改即可
#include<bits/stdc++.h>
#define ll long long
using namespace std;
vector<int>g[100005];
int fa[100005],pos[100005],siz[100005],son[100005],h[100005],top[100005],cnt=0,n;
ll A[100005],val[100005];
void dfs1(int u,int f)
{
int i,v;
siz[u]=1;
son[u]=0;
fa[u]=f;
h[u]=h[f]+1;
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f)
{
dfs1(v,u);
siz[u]+=siz[v];
if(siz[son[u]]<siz[v]) son[u]=v;
}
}
}
void dfs2(int u,int f,int k)
{
int i,v;
top[u]=k;
pos[u]=++cnt;
A[cnt]=val[u];
if(son[u]) dfs2(son[u],u,k);
for(i=0;i<g[u].size();i++)
{
v=g[u][i];
if(v!=f&&v!=son[u]) dfs2(v,u,v);
}
}
ll T[200010];
void update(int root,int l,int r,int k,ll v){
if(l==r)
{
T[root]=v;
return;
}
int mid=(l+r)/2;
if(k<=mid) update(root*2,l,mid,k,v);
else update(root*2+1,mid+1,r,k,v);
T[root]=T[root*2]+T[root*2+1];
}
ll qsum(int root,int l,int r,int L,int R)
{
if(L==l&&r==R) return T[root];
int mid=(l+r)/2;
if(R<=mid) return qsum(root*2,l,mid,L,R);
else if(L>mid) return qsum(root*2+1,mid+1,r,L,R);
else return qsum(root*2,l,mid,L,mid)+qsum(root*2+1,mid+1,r,mid+1,R);
}
ll fsum(int u,int v)
{
ll ans=0;
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ans+=qsum(1,1,n,pos[top[u]],pos[u]);
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ans+=qsum(1,1,n,pos[v],pos[u]);
return ans;
}
void ch(int root,int l,int r,int L,int R)
{
if(l==r)
{
T[root]=(ll)sqrt(T[root]);
return;
}
int mid=(l+r)/2;
if(mid>=L)
ch(root*2,l,mid,L,R);
if(mid+1<=R)
ch(root*2+1,mid+1,r,L,R);
T[root]=T[root*2]+T[root*2+1];
}
void op(int u,int v)
{
while(top[u]!=top[v])
{
if(h[top[u]]<h[top[v]]) swap(u,v);
ch(1,1,n,pos[top[u]],pos[u]);
u=fa[top[u]];
}
if(h[u]<h[v]) swap(u,v);
ch(1,1,n,pos[v],pos[u]);
}
int main(){
int a,b,q,i;
int s;
scanf("%d%d",&n,&q);
for(i=1;i<=n;i++)
scanf("%lld",&val[i]);
for(i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
g[a].push_back(b);
g[b].push_back(a);
}
dfs1(1,0);
dfs2(1,0,1);
for(i=1;i<=n;i++)
update(1,1,n,i,A[i]);
while(q--)
{
scanf("%d%d%d",&s,&a,&b);
if(s==0) op(a,b);
else printf("%lld\n",fsum(a,b));
}
return 0;
}