Description
给出一棵\(n(n\leq8\times10^4)\)个点的带点权的树,进行\(m(m\leq8\times10^4)\)次操作,操作有两种:
- 修改一个点的点权。
- 询问路径\((u,v)\)上第\(k\)大的点权。若路径上的点不足\(k\)个输出
invalid request!
。
Solution
带修改的可持久化线段树。
首先对于每个节点\(u\)建立一棵线段树记录路径\((u,rt)\)上的权值分布。考虑修改一个点的权值对这些线段树有什么影响。当\(u\)的权值由\(val\)变为\(val'\)后,子树\(u\)中的所有点的线段树中都有\(cnt[val]-1,cnt[val']+1\)。
对每个点再建立一棵线段树记录修改。修改子树在DFS序上相当于修改区间,差分后变为两个单点修改。于是我们要对于这些线段树进行单点修改,前缀查询;而这可以用树状数组实现。于是我们按DFS序建立树状数组套线段树就可以维护修改带来的影响。当我们需要求路径\((u,rt)\)上的权值分布时,用原线段树加上修改线段树即可。
时间复杂度\(O(mlog^3n)\)。
Code
//[CTSC2008]网络管理
#include <algorithm>
#include <cstdio>
#include <vector>
using namespace std;
inline char gc()
{
static char now[1<<16],*s,*t;
if(s==t) {t=(s=now)+fread(now,1,1<<16,stdin); if(s==t) return EOF;}
return *s++;
}
inline int read()
{
int x=0; char ch=gc();
while(ch<'0'||'9'<ch) ch=gc();
while('0'<=ch&&ch<='9') x=x*10+ch-'0',ch=gc();
return x;
}
const int N=8e4+10;
int n,m,w[N];
struct optR{int k,u,v;} seq[N];
int wCnt,map[N<<1];
void discrete()
{
int cnt=0;
for(int i=1;i<=n;i++) map[++cnt]=w[i];
for(int i=1;i<=m;i++) if(seq[i].k==0) map[++cnt]=seq[i].v;
sort(map+1,map+cnt+1); wCnt=unique(map+1,map+cnt+1)-map-1;
for(int i=1;i<=n;i++) w[i]=lower_bound(map+1,map+wCnt+1,w[i])-map;
for(int i=1;i<=m;i++) if(seq[i].k==0) seq[i].v=lower_bound(map+1,map+wCnt+1,seq[i].v)-map;
}
const int N1=2e7;
int ndCnt,rt1[N],rt2[N],ch[N1][2],sum[N1];
void trAdd(int t,int x,int v);
int trSum(int t,int L,int R);
void ndCopy(int p,int q) {ch[q][0]=ch[p][0],ch[q][1]=ch[p][1],sum[q]=sum[p];}
void ins1(int &p,int L0,int R0,int x)
{
ndCopy(p,++ndCnt); sum[p=ndCnt]++;
if(L0==R0) return;
int mid=L0+R0>>1;
if(x<=mid) ins1(ch[p][0],L0,mid,x);
else ins1(ch[p][1],mid+1,R0,x);
}
int t1,t2,t3,t4;
int query1(int p1,int p2,int p3,int p4,int L0,int R0,int k)
{
if(L0==R0) return map[L0];
int mid=L0+R0>>1,sumL=0;
sumL+=sum[ch[p1][0]]+trSum(t1,L0,mid)+sum[ch[p2][0]]+trSum(t2,L0,mid);
sumL-=sum[ch[p3][0]]+trSum(t3,L0,mid)+sum[ch[p4][0]]+trSum(t4,L0,mid);
if(sumL>=k) return query1(ch[p1][0],ch[p2][0],ch[p3][0],ch[p4][0],L0,mid,k);
else return query1(ch[p1][1],ch[p2][1],ch[p3][1],ch[p4][1],mid+1,R0,k-sumL);
}
void ins2(int &p,int L0,int R0,int x,int v)
{
if(!p) p=++ndCnt; sum[p]+=v;
if(L0==R0) return;
int mid=L0+R0>>1;
if(x<=mid) ins2(ch[p][0],L0,mid,x,v);
else ins2(ch[p][1],mid+1,R0,x,v);
}
int query2(int p,int L0,int R0,int optL,int optR)
{
if(optL<=L0&&R0<=optR) return sum[p];
int mid=L0+R0>>1,r=0;
if(optL<=mid) r+=query2(ch[p][0],L0,mid,optL,optR);
if(mid<optR) r+=query2(ch[p][1],mid+1,R0,optL,optR);
return r;
}
void trAdd(int t,int x,int v) {while(t<=n) ins2(rt2[t],1,wCnt,x,v),t+=t&(-t);}
int trSum(int t,int L,int R)
{
int r=0;
while(t) r+=query2(rt2[t],1,wCnt,L,R),t-=t&(-t);
return r;
}
vector<int> ed[N];
void edAdd(int u,int v) {ed[u].push_back(v),ed[v].push_back(u);}
int fa[N][20],dpt[N]; int dfCnt,fr[N],to[N];
void dfs(int u)
{
fr[u]=++dfCnt;
ins1(rt1[u]=rt1[fa[u][0]],1,wCnt,w[u]);
for(int k=1;fa[u][k-1];k++) fa[u][k]=fa[fa[u][k-1]][k-1];
for(int i=0;i<ed[u].size();i++)
{
int v=ed[u][i];
if(v==fa[u][0]) continue;
fa[v][0]=u,dpt[v]=dpt[u]+1;
dfs(v);
}
to[u]=dfCnt;
}
int lca(int u,int v)
{
if(dpt[u]<dpt[v]) swap(u,v);
for(int k=17;k>=0;k--) if(dpt[fa[u][k]]>=dpt[v]) u=fa[u][k];
if(u==v) return u;
for(int k=17;k>=0;k--) if(fa[u][k]!=fa[v][k]) u=fa[u][k],v=fa[v][k];
return fa[u][0];
}
int main()
{
n=read(),m=read();
for(int i=1;i<=n;i++) w[i]=read();
for(int i=1;i<=n-1;i++) edAdd(read(),read());
for(int i=1;i<=m;i++) seq[i].k=read(),seq[i].u=read(),seq[i].v=read();
discrete();
dpt[1]=1,dfs(1);
for(int i=1;i<=m;i++)
{
int k=seq[i].k,u=seq[i].u,v=seq[i].v;
if(k==0)
{
trAdd(fr[u],w[u],-1),trAdd(to[u]+1,w[u],1);
w[u]=v; trAdd(fr[u],v,1),trAdd(to[u]+1,v,-1);
}
else
{
int u1=lca(u,v),v1=fa[u1][0],sum1=0;
t1=fr[u],t2=fr[v],t3=fr[u1],t4=fr[v1];
sum1+=sum[rt1[u]]+trSum(t1,1,wCnt)+sum[rt1[v]]+trSum(t2,1,wCnt);
sum1-=sum[rt1[u1]]+trSum(t3,1,wCnt)+sum[rt1[v1]]+trSum(t4,1,wCnt);
if(sum1<k) puts("invalid request!");
else printf("%d\n",query1(rt1[u],rt1[v],rt1[u1],rt1[v1],1,wCnt,sum1-k+1));
}
}
return 0;
}
P.S.
姑且是把之前鸽了的题解都补完了...