题目
给定一棵N 个节点的树,标号从1~N。每个点有一个权值。要求维护两种操作:
1. C i x(0<=x<2^31) 表示将i 点权值变为x
2. Q i j x(0<=x<2^31) 表示询问i 到j 的路径上有多少个值为x 的节点
比较简单的想法就是按权值建N+Q棵可持久化线段树,第i棵线段树里记录权值为a[i]的节点的区间和,节点要用链剖后的dfs序来做编号,修改时就是将原权值所在线段树中的点删除,在新权值线段树里插入,查询就沿重链往上跳,先跳深度大的,然后在线段树里查询就行了。
时间复杂度:(N+Q)(log(N+Q))^2
空间复杂度:(N+Q)(log(N+Q))
贴代码
#include<iostream>
#include<cstdio>
#include<algorithm>
#define N 300001
using namespace std;
int n,m,sum,ans;
int fa[N],q[N][4],g[N],a[N+N][2],f[N*36][3],bz[N],h[N],v[N],b[N][2],c[N],w[N];
void ins(int x,int y){
static int sum=0;
a[++sum][0]=y,a[sum][1]=g[x],g[x]=sum;
}
void init(){
static char x;
static int y,z;
scanf("%d %d",&n,&m);
for (int i=1;i<=n;i++)
scanf("%d",&v[i]),c[++c[0]]=v[i];
for (int i=1;i<n;i++)
scanf("%d %d",&y,&z),ins(y,z),ins(z,y);
for (int i=1;i<=m;i++){
scanf(" %c %d %d",&x,&q[i][1],&q[i][2]);
if (x=='Q')
scanf("%d",&q[i][3]),q[i][0]=1;
else
c[++c[0]]=q[i][2];
}
}
void dfs(int x){
for (int i=g[x];i;i=a[i][1])
if (a[i][0]!=fa[x]){
fa[a[i][0]]=x;
dfs(a[i][0]);
if (b[x][0]<b[a[i][0]][0]+1)
b[x][0]=b[a[i][0]][0]+1,b[x][1]=a[i][0];
}
}
void dfs1(int x){
static int sum=0;
bz[x]=++sum;
if (b[x][1])
h[b[x][1]]=h[x],dfs1(b[x][1]);
for (int i=g[x];i;i=a[i][1])
if (a[i][0]!=fa[x]&&a[i][0]!=b[x][1])
h[a[i][0]]=a[i][0],dfs1(a[i][0]);
}
void change(int l,int r,int s,int ll,int z){
f[s][2]+=z;
if (l==r)return;
static int ss;
if ((ss=(l+r)/2)>=ll){
if (!f[s][0])f[s][0]=++sum;
change(l,ss,f[s][0],ll,z);
}else{
if (!f[s][1])f[s][1]=++sum;
change(ss+1,r,f[s][1],ll,z);
}
}
int er(int x){
static int l,r,mid;
l=1,r=c[0];
while (l<=r)if (c[mid=(l+r)/2]<x)l=mid+1;else r=mid-1;
if (l<1)l=1;
if (l>c[0])l=c[0];
while (c[l]<x&&l<c[0])l++;
while (c[l]>x&&l>1)l--;
return l;
}
void pre(){
static int x;
sort(c+1,c+c[0]+1);
x=c[0],c[0]=1;
for (int i=2;i<=x;i++)
if (c[c[0]]!=c[i])c[++c[0]]=c[i];
dfs(1);
h[1]=1;
dfs1(1);
for (int i=1;i<=c[0];i++)
w[i]=++sum;
for (int i=1;i<=n;i++)
change(1,n,w[er(v[i])],bz[i],1);
}
void find(int l,int r,int s,int ll,int rr){
if (!f[s][2])return;
if (rr<l||r<ll)return;
if (ll<=l&&r<=rr){
ans+=f[s][2];
return;
}
find(l,(l+r)/2,f[s][0],ll,rr),find((l+r)/2+1,r,f[s][1],ll,rr);
}
void up(int x,int y,int z){
static int ss;
ss=z;
z=er(z);
if (c[z]!=ss)return;
while (h[x]!=h[y]){
if (bz[h[x]]<bz[h[y]])swap(x,y);
find(1,n,w[z],bz[h[x]],bz[x]);
x=fa[h[x]];
}
if (bz[x]<bz[y])swap(x,y);
find(1,n,w[z],bz[y],bz[x]);
}
void work(){
for (int i=1;i<=m;i++)
if (q[i][0]){
ans=0;
up(q[i][1],q[i][2],q[i][3]);
printf("%d\n",ans);
}else{
change(1,n,w[er(v[q[i][1]])],bz[q[i][1]],-1);
v[q[i][1]]=q[i][2];
change(1,n,w[er(q[i][2])],bz[q[i][1]],1);
}
}
int main(){
init();
pre();
work();
return 0;
}