ZOJ 2112
题意
给你n个数,有q次操作,每次操作可以修改某一个数,或是求区间第k小值。(多组数据)
样例输入
2
5 3
3 2 1 4 7
Q 1 4 3
C 2 6
Q 2 5 3
5 3
3 2 1 4 7
Q 1 4 3
C 2 6
Q 2 5 3
样例输出
3
6
3
6
SOL
如果不考虑单点修改就是主席树裸题。主席树本质上使用前缀和维护的,查询复杂度为O(1),但修改复杂度为O(n)。如果不用前缀和,查询复杂度为O(n),修改复杂度为O(n)。(这不是废话么…)
考虑在外面套上一层树状数组,使得查询和修改复杂度均为O(logn)。
简单的实现方式
普通的树状数组的每一个点存的都是某几个数的和(和lowbit有关),那么本道题的每一个点存的都是一棵线段树,并且相邻线段树之间用主席树的方式连接。
时间复杂度分析
修改:每次最多对log(n)棵线段树修改,每个节点中存的线段树要修改log(n)次,最多加入(n+m)个数(最开始要放入n个数),所以时间复杂度为O((n+m)lognlogn)。
查询:每次需要累加最多log(n)棵线段树,每棵线段树累加log(n)个节点,所以时间复杂度为O(mlognlogn)。
空间复杂度分析
每次操作修改log(n)棵线段树,每棵线段树修改log(n)个节点,所以空间复杂度为O((n+m)lognlogn)。
从这里可以可以看出时间复杂度和空间复杂度都是nlognlogn级别的,但是在ZOJ上内存限制只有64M,也就是说这种方式会MLE,如何解决?
静态建树+动态修改
分析数据范围可以看出这道题目的m比较小,所以可以选择一种神奇的方法:把查询的区间分为两步,第一步求出原区间,第二步加上修改的增量。第一步显然就是裸的主席树,第二步也就是没有初值的修改,不同之处在于只要修改m就可以了。空间复杂度级数没有变,但是常数大概能小4~6倍,空间刚好可以卡过去。下面就用这种方法具体解释如何实现。
建树
为了自己能够更好的掌握,这里的建树和更新方法都采用裸二叉树的方法。(还有许多方法,比如不需要用递归实现的,以后有空可以学习一下,基本上递归方式能够看懂非递归方式就很简单了)
void build(int l,int r,int &rt)
{
rt = ++tot;
sum[rt] = 0; if (l == r) return;
int m = (l + r) >> 1;
build(l,m,ls[rt]);
build(m+1,r,rs[rt]);
}
更新
void update(int l,int r,int &rt,int last,int p,int delta)
{
rt = ++tot;
ls[rt] = ls[last]; rs[rt] = rs[last];
sum[rt] = sum[last] + delta;
if (l == r) return;
int m = (l + r) >> 1;
if (p <= m) update(l,m,ls[rt],ls[last],p,delta);
else update(m+1,r,rs[rt],rs[last],p,delta);
}
查询
个人感觉最难的部分就是查询。从这里可以看出非递归方式写起来真的非常清爽。。。特别是当树套树一层一层搞不清楚的时候写成递归真的要死。。。
这里的S数组存的是树状数组里面每个节点代表的线段树的根的编号(可以对比:root数组存的是原始数组里面每个节点代表的线段树的根的编号),怎么更新S后面会讲到。use1/2数组存的就是L-1和R的lowbit路径。
int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]];
这句话就是我刚才提到的“把查询的区间分为两步,第一步求出原区间,第二步加上修改的增量”。显然value过程求的是增量,sum数组显然是原区间的情况。
int value1(int x)
{
int re = 0;
while (x > 0) {re += sum[ls[use1[x]]]; x -= lowbit(x);}
return re;
}
int value2(int x)
{
int re = 0;
while (x > 0) {re += sum[ls[use2[x]]]; x -= lowbit(x);}
return re;
}
int Query(int L,int R,int k)
{
int lrt = root[L-1];
int rrt = root[R];
int l = 1,r = mm;
for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = S[i];
for (int i = R;i ; i -= lowbit(i)) use2[i] = S[i];
while (l < r)
{
int m = (l + r) >> 1;
int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]];
if (k <= cnt)
{
r = m;
for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = ls[use1[i]];
for (int i = R;i ; i -= lowbit(i)) use2[i] = ls[use2[i]];
lrt = ls[lrt]; rrt = ls[rrt];
} else
{
l = m + 1; k = k - cnt;
for (int i = L - 1;i > 0; i -= lowbit(i)) use1[i] = rs[use1[i]];
for (int i = R;i > 0; i -= lowbit(i)) use2[i] = rs[use2[i]];
lrt = rs[lrt]; rrt = rs[rrt];
}
}
return l;
}
修改
S数组一开始全部连到root[0]上表示全部为空。接下来每次修改一个数就新开一棵线段树并且用S记录位置。
void change(int x,int p,int delta)
{
while (x<=n)
{
update(1,k,S[x],S[x],p,delta);
x += lowbit(x);
}
}
完整代码
#include<cmath>
#include<cstdio>
#include<vector>
#include<cstring>
#include<iomanip>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
#define ll long long
#define inf 1000000000
#define mod 1000000007
#define N 2500010
#define M 60010
using namespace std;
struct data{int kind,l,r,k;}query[10010];
char op[10];
int T,n,mm,q,i,tot,k;
int a[M],b[M],sum[N],ls[N],rs[N],root[M],S[M],use1[M],use2[M];
void build(int l,int r,int &rt)
{
rt = ++tot;
sum[rt] = 0; if (l == r) return;
int m = (l + r) >> 1;
build(l,m,ls[rt]);
build(m+1,r,rs[rt]);
}
void update(int l,int r,int &rt,int last,int p,int delta)
{
rt = ++tot;
ls[rt] = ls[last]; rs[rt] = rs[last];
sum[rt] = sum[last] + delta;
if (l == r) return;
int m = (l + r) >> 1;
if (p <= m) update(l,m,ls[rt],ls[last],p,delta);
else update(m+1,r,rs[rt],rs[last],p,delta);
}
int lowbit(int x){return x&(-x);}
int value1(int x)
{
int re = 0;
while (x > 0) {re += sum[ls[use1[x]]]; x -= lowbit(x);}
return re;
}
int value2(int x)
{
int re = 0;
while (x > 0) {re += sum[ls[use2[x]]]; x -= lowbit(x);}
return re;
}
int Query(int L,int R,int k)
{
int lrt = root[L-1];
int rrt = root[R];
int l = 1,r = mm;
for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = S[i];
for (int i = R;i ; i -= lowbit(i)) use2[i] = S[i];
while (l < r)
{
int m = (l + r) >> 1;
int cnt = value2(R) - value1(L-1) + sum[ls[rrt]] - sum[ls[lrt]];
if (k <= cnt)
{
r = m;
for (int i = L - 1;i ; i -= lowbit(i)) use1[i] = ls[use1[i]];
for (int i = R;i ; i -= lowbit(i)) use2[i] = ls[use2[i]];
lrt = ls[lrt]; rrt = ls[rrt];
} else
{
l = m + 1; k = k - cnt;
for (int i = L - 1;i > 0; i -= lowbit(i)) use1[i] = rs[use1[i]];
for (int i = R;i > 0; i -= lowbit(i)) use2[i] = rs[use2[i]];
lrt = rs[lrt]; rrt = rs[rrt];
}
}
return l;
}
void change(int x,int p,int delta)
{
while (x<=n)
{
update(1,k,S[x],S[x],p,delta);
x += lowbit(x);
}
}
int hash(int x)
{
return lower_bound(b+1,b+k+1,x)-b;
}
int main()
{
cin>>T;
while (T--)
{
cin>>n>>q;
for (i = 1;i <= n; i++) scanf("%d",&a[i]);
for (i = 1;i <= n; i++) b[i] = a[i]; k = n;
for (i = 1;i <= q; i++)
{
scanf("%s",op);
if (op[0] == 'Q')
{
query[i].kind = 0;
scanf("%d%d%d",&query[i].l,&query[i].r,&query[i].k);
} else
{
query[i].kind = 1;
scanf("%d%d",&query[i].l,&query[i].r);
b[++k] = query[i].r;
}
}
sort(b+1,b+k+1);
k = unique(b+1,b+k+1) - (b+1);
tot = 0;
build(1,k,root[0]);
for (i = 1;i <= n; i++) update(1,k,root[i],root[i-1],hash(a[i]),1);
for (i = 1;i <= n; i++) S[i] = root[0];
mm = k;
for (i = 1;i <= q; i++)
{
if (query[i].kind == 0)
printf("%d\n",b[Query(query[i].l,query[i].r,query[i].k)]);
else
{
change(query[i].l,hash(a[query[i].l]),-1);
change(query[i].l,hash(query[i].r),1);
a[query[i].l] = query[i].r;
}
}
}
return 0;
}