http://codevs.cn/problem/1514/
题解:
Splay,因为没办法根据给出的书的编号确定书的位置,也就没办法做到log(n)的查询,所以采用自底向上的伸展方式,那么就需要用pa数组记录节点的上个结点。
算法实现上建立了两个虚拟节点来避免溢出——据HZWER。
加了很多注解。
体会了一下,因为move操作的统一性,所以加上两个虚拟节点后有利于继续维护插入操作的统一性。否则除非给移动到顶和底层的另写函数,没办法用一个move函数解决了否则溢出,因为空节点不能作为根节点。
写了蚱蜢以后才逐渐理解。
第一个用的HZWER的splay,第二个是我写的。
代码:
版本1:
总时间耗费: 1637ms
总内存耗费: 2 kB
总内存耗费: 2 kB
#include<cstdio>
#include<cstring>
using namespace std;
const int INF = 1e9 + 7;
const int maxn = 80000 + 10;
int n, m, root;
int ch[maxn][2], p[maxn], a[maxn], s[maxn], v[maxn], id[maxn];
void update(int k) {
s[k] = s[ch[k][0]] + s[ch[k][1]] + 1;
}
void rotate(int x, int &k) {
int y = p[x], z = p[y];
int d = ch[y][0] == x ? 0 : 1;
int d2 = ch[z][0] == y ? 0 : 1;
if(y == k) k = x; else ch[z][d2] = x;
p[x] = z; p[y] = x; p[ch[x][d^1]] = y;
ch[y][d] = ch[x][d^1]; ch[x][d^1] = y;
update(y); update(x);
}
//自底向上的伸展
void splay(int x, int &k) {
while(x != k) {
int y = p[x], z = p[y];
if(y != k) {
if(ch[y][0] == x ^ ch[z][0] == y) rotate(x, k);
else rotate(y, k);
}
rotate(x, k);
}
}
//类似线段树的build
void build(int l, int r, int pa) {
if(l > r) return;
if(l == r) {
v[l] = a[l]; s[l] = 1; p[l] = pa;
if(l < pa) ch[pa][0] = l; else ch[pa][1] = l;
return;
}
int mid = (l+r) >> 1;
build(l, mid-1, mid); build(mid+1, r, mid);
v[mid] = a[mid]; p[mid] = pa; update(mid);
if(mid < pa) ch[pa][0] = mid;
else ch[pa][1] = mid;
}
//在这个树中查询第rank个
int find(int k, int rank) {
int l = ch[k][0], r = ch[k][1];
if(s[l]+1 == rank) return k;
else if(s[l] >= rank) return find(l, rank);
else return find(r, rank-s[l]-1);
}
//删除排在第k位的
void remove(int k) {
int x, y, z;
x = find(root, k-1); y = find(root, k+1);
splay(x, root); splay(y, ch[x][1]);
z = ch[y][0]; ch[y][0] = 0; p[z] = s[z] = 0;
update(y); update(x);
}
void move(int k, int val) {
int o = id[k], x, y, rank;
splay(o, root);
rank = s[ch[o][0]] + 1;
remove(rank);
//x是要插入的位置的上一个点,y是要插入的位置
//最后插入到y的左儿子也就是取代了y原来的位置
if(val == INF) x = find(root, n), y = find(root, n+1); //插入到底部 此时共n+1个点 成为第n+1个点
else if(val == -INF) x = find(root, 1), y = find(root, 2); //插入到顶部 成为第2个点
else x = find(root, rank+val-1), y = find(root, rank+val); //插入到中间 成为第rank+val个点(rank的值是原来的排名+1)
splay(x, root); splay(y, ch[x][1]);
s[o] = 1; p[o] = y; ch[y][0] = o;
update(y); update(x);
}
int main() {
scanf("%d%d", &n, &m);
//加入1和n+2两个虚拟结点,避免溢出
for(int i = 2; i <= n+1; i++)
scanf("%d", &a[i]), id[a[i]] = i;
build(1, n+2, 0);
root = (n+3) >> 1;
char cmd[10]; int S, T;
for(int i = 1; i <= m; i++) {
scanf("%s%d", cmd, &S); //避免了标准函数读char读到回车什么奇怪的东西
switch(cmd[0]) {
case 'T': move(S, -INF); break;
case 'B': move(S, INF); break;
case 'I': scanf("%d", &T); move(S, T); break;
case 'A': splay(id[S], root); printf("%d\n", s[ch[id[S]][0]]-1); break;
case 'Q': printf("%d\n", v[find(root, S+1)]); break;
}
}
return 0;
}
版本2:
总时间耗费: 1067ms
总内存耗费: 3 MB
#include<cstdio>
#include<cstring>
using namespace std;
const int INF = 1e9 + 7;
const int maxn = 80000 + 10;
int n, m, root;
int ch[maxn][2], p[maxn], a[maxn], s[maxn], v[maxn], id[maxn];
void update(int k) {
s[k] = s[ch[k][0]] + s[ch[k][1]] + 1;
}
void rotate(int& px, int& x, int d) {
int t = ch[x][d]; ch[x][d] = px; ch[px][d^1] = t;
p[x] = p[px]; p[px] = x; p[t] = px; update(px); update(x); px = x;
}
void splay(int x, int& k) {
while(x != k) {
int y = p[x], z = p[y];
int d = ch[y][0] == x ? 0 : 1;
int d2 = ch[z][0] == y ? 0 : 1;
if(y != k) rotate(ch[z][d2], x, d^1); else rotate(k, x, d^1);
}
}
void build(int L, int R, int P, int d) {
if(L == R) { s[L] = 1; p[L] = P; ch[P][d] = L; return; }
int M = (L+R) >> 1;
p[M] = P; ch[P][d] = M;
if(M-1 >= L) build(L, M-1, M, 0);
if(R >= M+1) build(M+1, R, M, 1);
update(M);
}
int find(int k, int rank) {
int l = ch[k][0], r = ch[k][1];
if(s[l]+1 == rank) return k;
else if(s[l] >= rank) return find(l, rank);
else return find(r, rank-s[l]-1);
}
void remove(int k) {
int x, y, z;
x = find(root, k-1); y = find(root, k+1);
splay(x, root); splay(y, ch[x][1]);
z = ch[y][0]; ch[y][0] = 0; p[z] = s[z] = 0;
update(y); update(x);
}
void move(int k, int val) {
int o = id[k], x, y, rank;
splay(o, root);
rank = s[ch[o][0]] + 1;
remove(rank);
if(val == INF) x = find(root, n), y = find(root, n+1);
else if(val == -INF) x = find(root, 1), y = find(root, 2);
else x = find(root, rank+val-1), y = find(root, rank+val);
splay(x, root); splay(y, ch[x][1]);
s[o] = 1; p[o] = y; ch[y][0] = o;
update(y); update(x);
}
int main() {
scanf("%d%d", &n, &m);
for(int i = 2; i <= n+1; i++) scanf("%d", &v[i]), id[v[i]] = i;
build(1, n+2, 0, 1);
root = (n+3) >> 1;
char cmd[10]; int S, T;
for(int i = 1; i <= m; i++) {
scanf("%s%d", cmd, &S);
switch(cmd[0]) {
case 'T': move(S, -INF); break;
case 'B': move(S, INF); break;
case 'I': scanf("%d", &T); move(S, T); break;
case 'A': splay(id[S], root); printf("%d\n", s[ch[id[S]][0]]-1); break;
case 'Q': printf("%d\n", v[find(root, S+1)]); break;
}
}
return 0;
}