比赛做到一道虚树题,发现原来强行背下的建虚树方法忘了,只能爆零了。。比赛完后重新学了一遍虚树,终于貌似理解了(原本虚树学习笔记已更新)。
题意:给一棵树,树上点有点权,现给出一些询问,询问k个点及所有在k个点之间点对上的点的点权和,sigma k<=1e6,同时还要支持单点修改。
做法:每次询问建一次虚树,维护每个点到根节点的路径上的点权之和,修改直接用树状数组差分,总复杂度nlog。
代码1:(标准用栈建虚树)
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10;
int n,q,a[N];
vector<int>mp[N];
int fa[N],son[N],sz[N],top[N],dep[N],tim,dfn[N],id[N];
int qry[N],qnum;stack<int>stk;
ll tr[N],dis[N];
void read(int &x)
{
char c=getchar();x=0;bool r=0;
while(!isdigit(c))
{
if(c=='-')r=1;
c=getchar();
}
while(isdigit(c))x=x*10+c-48,c=getchar();
if(r)x=-x;
}
bool cmp(int x,int y)
{return id[x]<id[y];}
void dfs(int pos,int f,int dp)
{
fa[pos]=f,sz[pos]=1,dep[pos]=dp;
for(int i=0;i<mp[pos].size();i++)
{
if(mp[pos][i]!=f)
{
dfs(mp[pos][i],pos,dp+1);
if(sz[mp[pos][i]]>sz[son[pos]])son[pos]=mp[pos][i];
sz[pos]+=sz[mp[pos][i]];
}
}
}
void Dfs(int pos,int tp,ll D)
{
top[pos]=tp,dfn[++tim]=pos,id[pos]=tim,dis[pos]=D+a[pos];
if(son[pos])
{
Dfs(son[pos],tp,D+a[pos]);
for(int i=0;i<mp[pos].size();i++)
{
if(mp[pos][i]!=fa[pos]&&mp[pos][i]!=son[pos])
Dfs(mp[pos][i],mp[pos][i],D+a[pos]);
}
}
}
int lca(int x,int y)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
x=fa[fx],fx=top[x];
}
if(dep[x]<dep[y])return x;
else return y;
}
void ad(int x,int w)
{for(;x<=tim;x+=x&-x)tr[x]+=w;}
void add(int l,int r,int w)
{ad(l,w),ad(r+1,-w);}
ll ask(int x)
{
ll res=0;
for(;x;x-=x&-x)res+=tr[x];
return res;
}
int main()
{
int u,v,l,x1,x2;char op;
ll res;
// freopen("ex_kaihuang.in","r",stdin);
// freopen("my.out","w",stdout);
read(n),read(q);
for(int i=1;i<=n;i++)
read(a[i]);
for(int i=1;i<n;i++)
{
read(u),read(v);
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs(1,0,1),Dfs(1,1,0);
while(q--)
{
scanf(" %c",&op);
if(op=='C')
{
read(u),read(v);
add(id[u],id[u]+sz[u]-1,v-a[u]);
a[u]=v;
}
else
{
qnum=0,res=0;
while(1)
{
read(u);
if(u==0)break;
qry[++qnum]=u;
}
sort(qry+1,qry+qnum+1,cmp);
for(int i=1;i<=qnum;i++)
{
if(stk.empty())stk.push(qry[i]);
else
{
l=lca(stk.top(),qry[i]);
while(!stk.empty())
{
if(dep[stk.top()]<=dep[l])break;
else
{
x1=stk.top(),stk.pop();
if(!stk.empty())
{
x2=stk.top();
if(dep[x2]<=dep[l])res+=dis[x1]-dis[l]+ask(id[x1])-ask(id[l]);
else res+=dis[x1]-dis[x2]+ask(id[x1])-ask(id[x2]);
}
else res+=dis[x1]-dis[l]+ask(id[x1])-ask(id[l]);
}
}
if(stk.empty())stk.push(l),stk.push(qry[i]);
else
{
if(dep[stk.top()]!=dep[l])stk.push(l),stk.push(qry[i]);
else stk.push(qry[i]);
}
}
}
while(!stk.empty())
{
x1=stk.top(),stk.pop();
if(!stk.empty())
{
x2=stk.top();
res+=dis[x1]-dis[x2]+ask(id[x1])-ask(id[x2]);
}
else res+=a[x1],assert(x1==lca(qry[1],qry[qnum]));
}
printf("%lld\n",res);
}
}
}
代码2:%%%fx巨佬,这道题发现并不需要把虚树真的建出来,只要求每个点到它和sort完恰在它之前那个点的lca的距离和即可,不过注意要先把虚树的根(即sort完第一个点与最后一个点lca)加入。
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int N=1e5+10;
int n,q,a[N];
vector<int>mp[N];
int fa[N],son[N],sz[N],top[N],dep[N],tim,dfn[N],id[N];
int qry[N],qnum;stack<int>stk;
ll tr[N],dis[N];
void read(int &x)
{
char c=getchar();x=0;bool r=0;
while(!isdigit(c))
{
if(c=='-')r=1;
c=getchar();
}
while(isdigit(c))x=x*10+c-48,c=getchar();
if(r)x=-x;
}
bool cmp(int x,int y)
{return id[x]<id[y];}
void dfs(int pos,int f,int dp)
{
fa[pos]=f,sz[pos]=1,dep[pos]=dp;
for(int i=0;i<mp[pos].size();i++)
{
if(mp[pos][i]!=f)
{
dfs(mp[pos][i],pos,dp+1);
if(sz[mp[pos][i]]>sz[son[pos]])son[pos]=mp[pos][i];
sz[pos]+=sz[mp[pos][i]];
}
}
}
void Dfs(int pos,int tp,ll D)
{
top[pos]=tp,dfn[++tim]=pos,id[pos]=tim,dis[pos]=D+a[pos];
if(son[pos])
{
Dfs(son[pos],tp,D+a[pos]);
for(int i=0;i<mp[pos].size();i++)
{
if(mp[pos][i]!=fa[pos]&&mp[pos][i]!=son[pos])
Dfs(mp[pos][i],mp[pos][i],D+a[pos]);
}
}
}
int lca(int x,int y)
{
int fx=top[x],fy=top[y];
while(fx!=fy)
{
if(dep[fx]<dep[fy])swap(x,y),swap(fx,fy);
x=fa[fx],fx=top[x];
}
if(dep[x]<dep[y])return x;
else return y;
}
void ad(int x,int w)
{for(;x<=tim;x+=x&-x)tr[x]+=w;}
void add(int l,int r,int w)
{ad(l,w),ad(r+1,-w);}
ll ask(int x)
{
ll res=0;
for(;x;x-=x&-x)res+=tr[x];
return res;
}
int main()
{
int u,v,l;char op;
ll res;
read(n),read(q);
for(int i=1;i<=n;i++)
read(a[i]);
for(int i=1;i<n;i++)
{
read(u),read(v);
mp[u].push_back(v);
mp[v].push_back(u);
}
dfs(1,0,1),Dfs(1,1,0);
while(q--)
{
scanf(" %c",&op);
if(op=='C')
{
read(u),read(v);
add(id[u],id[u]+sz[u]-1,v-a[u]);
a[u]=v;
}
else
{
qnum=0;
while(1)
{
read(u);
if(u==0)break;
qry[++qnum]=u;
}
sort(qry+1,qry+qnum+1,cmp);
qry[0]=lca(qry[1],qry[qnum]);
res=a[qry[0]];
for(int i=1;i<=qnum;i++)
{
l=lca(qry[i],qry[i-1]);
res+=dis[qry[i]]-dis[l]+ask(id[qry[i]])-ask(id[l]);
}
printf("%lld\n",res);
}
}
}