树分治,设当前树的分治中心为x,其子树分治中心为y,则设father[y]=x,分治下去则可以得到一颗重心树,而且树的深度是logn。
询问操作(x,d),只需要查询重心树上x到重心树根节点上的节点的累加和。假设当前节点是y,那么节点y可以贡献的答案是那些以y为分治中心且到y距离为d-dis(x,y)的节点的总和。当然这样可能会出现重复的情况,重复情况只会出现在包含x的那颗子树上,因此减掉即可。修改操作类似。复杂度O(nlognlogn)
代码
#include<cstdio>
#include<cstring>
#define N 200010
#define LL long long
using namespace std;
int dp,pre[N],p[N],tt[N],vis[N],father[N],s[N],tmp,m;
int n,a,b,i,w[N],L,cnt,tot,len[N],Len[N],start[N],Start[N],v[N];
int deep[N],ss[N][21],fa[N];
int c[N*50];
int min(int a,int b)
{
if (a<b) return a;return b;
}
int lowbit(int x)
{
return x&(-x);
}
void cc(int x,int w,int y)
{
while (x<=L)
{
c[y+x]+=w;
x+=lowbit(x);
}
}
LL sum(int x,int y)
{
LL ans=0;
while (x>0)
{
ans+=c[y+x];
x-=lowbit(x);
}
return ans;
}
void link(int x,int y)
{
dp++;pre[dp]=p[x];p[x]=dp;tt[dp]=y;
}
void gao(int x)
{
int i;
i=p[x];
while (i)
{
if (tt[i]!=fa[x])
{
fa[tt[i]]=x;
deep[tt[i]]=deep[x]+1;
gao(tt[i]);
}
i=pre[i];
}
}
int lca(int x,int y)
{
if(deep[x]>deep[y])x^=y^=x^=y;
int i;
for(i=19;i>=0;i--)
{
if(deep[y]-deep[x]>=(1<<i))
{
y=ss[y][i];
}
}
if(x==y)return x;
for(i=19;i>=0;i--)
{
if(ss[x][i]!=ss[y][i])
{
x=ss[x][i];
y=ss[y][i];
}
}
return fa[x];
}
void getroot(int x,int fa,int sum)
{
int i,flag=0;
i=p[x];s[x]=1;
while (i)
{
if ((!vis[tt[i]])&&(tt[i]!=fa))
{
getroot(tt[i],x,sum);
s[x]+=s[tt[i]];
if (s[tt[i]]>sum/2) flag=1;
}
i=pre[i];
}
if (sum-s[x]>sum/2) flag=1;
if (!flag) tmp=x;
}
void dfs(int x,int fa,int dis)
{
int i;
i=p[x];
if (dis>cnt) cnt=dis;
v[dis]+=w[x];
while (i)
{
if ((!vis[tt[i]])&&(tt[i]!=fa))
dfs(tt[i],x,dis+1);
i=pre[i];
}
}
void clear()
{
int i;
for (i=1;i<=cnt;i++)
v[i]=0;cnt=0;
}
int work(int x,int fa,int sum)
{
int i,root,t;
getroot(x,0,sum);
root=tmp;
father[root]=fa;
i=p[root];
vis[root]=1;
while (i)
{
if (!vis[tt[i]])
{
if (s[root]>s[tt[i]])
t=work(tt[i],root,s[tt[i]]);
else
t=work(tt[i],root,sum-s[root]);
//------dist(root,point in subtree t)--------
dfs(tt[i],0,2);
Len[t]=cnt;
Start[t]=tot;
for (int j=1;j<=cnt;j++)
{
L=cnt;
cc(j,v[j],Start[t]);
}
tot+=cnt;
clear();
}
i=pre[i];
}
vis[root]=0;
//--------dist(root,all point)----------
dfs(root,0,1);
len[root]=cnt;
start[root]=tot;
for (i=1;i<=cnt;i++)
{
L=cnt;
cc(i,v[i],start[root]);
}
tot+=cnt;
clear();
return root;
}
LL query(int x,int d)
{
int y=0,z=x,t;
LL ans=0;
while (x)
{
t=lca(x,z);
t=deep[x]+deep[z]-2*deep[t];
L=len[x];
ans+=sum(min(L,d-t+1),start[x]);
if (y)
{
L=Len[y];
ans-=sum(min(L,d-t+1),Start[y]);
}
y=x;
x=father[x];
}
return ans;
}
void change(int x,int w)
{
int y=0,z=x,t;
while (x)
{
t=lca(x,z);
t=deep[x]+deep[z]-2*deep[t];
L=len[x];
cc(t+1,w,start[x]);
if (y)
{
L=Len[y];
cc(t+1,w,Start[y]);
}
y=x;
x=father[x];
}
}
int main()
{
while (scanf("%d%d",&n,&m)!=EOF)
{
dp=0;memset(p,0,sizeof(p));
for (i=1;i<=tot;i++)
c[i]=0;tot=0;
for (i=1;i<=n;i++)
scanf("%d",&w[i]);
for (i=1;i<n;i++)
{
scanf("%d%d",&a,&b);
link(a,b);
link(b,a);
}
gao(1);
for(i=1;i<=n;i++)
ss[i][0]=fa[i];
for(int h=1;h<20;h++)
{
for(i=1;i<=n;i++)
{
ss[i][h]=ss[ss[i][h-1]][h-1];
}
}
work(1,0,n);
for (i=1;i<=m;i++)
{
getchar();
char ch;
scanf("%c%d%d",&ch,&a,&b);
if (ch=='?')
printf("%I64d\n",query(a,b));
else
{
change(a,b-w[a]);
w[a]=b;
}
}
}
}