代码来自洛谷日报,感谢洛谷
本篇文章没有splay的教学,只有代码的使用说明。
直接贴上代码:
这个是不带 reverse 的模板:
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
// Node-pool capacity; node id 0 is reserved as the "no node" value (ids start at 1).
const int maxn = 2e5+7;
// n: number of operations; op / x: opcode and operand of the current operation.
int n, op, x;
// Splay tree stored in parallel arrays indexed by node id:
//   ch[u][0], ch[u][1] - left / right child (0 if absent)
//   fa[u]              - parent (0 for the root)
//   val[u]             - key stored in u
//   cnt[u]             - number of copies of val[u]
//   siz[u]             - total number of copies in u's subtree (maintained by pushup)
//   ncnt               - number of node ids handed out so far
//   root               - current root id (0 while the tree is empty)
int ch[maxn][2], fa[maxn], val[maxn], cnt[maxn], siz[maxn], ncnt, root;
// Tells which side x hangs on: true if x is its parent's right child, false if left.
bool chk(int x)
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
// Recomputes siz[x] as x's own multiplicity plus the sizes of both subtrees.
void pushup(int x)
{
    const int lc = ch[x][0];
    const int rc = ch[x][1];
    siz[x] = cnt[x] + siz[lc] + siz[rc];
}
// Rotates x one level up, above its parent y, keeping the in-order sequence.
// k records which side x hangs on; w is x's child on the opposite side, which
// moves across to become y's child on side k.
// When w == 0 or y is the root (z == 0), the writes to fa[0] / ch[0][*] only
// touch the null id 0.
// y is refreshed before x because y ends up as x's child.
void rotate(int x)
{
int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
ch[y][k] = w;
fa[w] = y;
// Put x where y used to hang under z; chk(y) still sees the old links here.
ch[z][chk(y)] = x;
fa[x] = z;
ch[x][k^1] = y;
fa[y] = x;
pushup(y);
pushup(x);
}
// Splays x upward until its parent is goal. goal = 0 (the default) makes x the
// root and records it in the global root.
// While the grandparent is not goal, a double rotation is done. If x and its
// parent are on the same side (zig-zig), the parent is rotated first; otherwise
// (zig-zag) x is rotated twice. When only the parent remains below goal, a
// single rotation of x finishes the job.
void splay(int x, int goal = 0)
{
while(fa[x] != goal)
{
int y = fa[x], z = fa[y];
if (z != goal)
{
if(chk(x) == chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if(!goal) root = x;
}
void insert(int x)
{
int cur = root, p = 0;
while(cur && val[cur] != x)
{
p = cur;
cur = ch[cur][x > val[cur]];
}
if(cur) cnt[cur]++;
else
{
cur = ++ncnt;
if(p) ch[p][x > val[p]] = cur;
ch[cur][0] = ch[cur][1] = 0;
fa[cur] = p;
val[cur] = x;
cnt[cur] = siz[cur] = 1;
}
splay(cur);
}
// Searches for x and splays the last node reached to the root. That node holds
// x when x is present; otherwise it is a neighbour of x in sorted order.
void find(int x)
{
    int cur = root;
    while (val[cur] != x)
    {
        const int step = ch[cur][x > val[cur]];
        if (!step) break;
        cur = step;
    }
    splay(cur);
}
// Returns the node holding the k-th smallest stored value (1-based, copies
// counted via cnt). Does not splay.
int kth(int k)
{
    int cur = root;
    for (;;)
    {
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        // Copies in the left subtree plus the copies held by cur itself.
        const int upto = siz[lc] + cnt[cur];
        if (k <= upto)
            return cur;
        k -= upto;
        cur = ch[cur][1];
    }
}
// Returns the node holding the largest value strictly smaller than x.
// find(x) first brings x, or the node where its search ended, to the root.
int pre(int x)
{
    find(x);
    if (val[root] >= x)
    {
        // The root is not smaller than x: the answer is the maximum of its left subtree.
        int cur = ch[root][0];
        while (ch[cur][1] != 0)
            cur = ch[cur][1];
        return cur;
    }
    return root;
}
// Returns the node holding the smallest value strictly greater than x.
// find(x) first brings x, or the node where its search ended, to the root.
int succ(int x)
{
    find(x);
    if (val[root] <= x)
    {
        // The root is not greater than x: the answer is the minimum of its right subtree.
        int cur = ch[root][1];
        while (ch[cur][0] != 0)
            cur = ch[cur][0];
        return cur;
    }
    return root;
}
void remove(int x)
{
int last = pre(x), next = succ(x);
splay(last);
splay(next, last);
int del = ch[next][0];
if(cnt[del] > 1)
{
cnt[del]--;
splay(del);
}
else ch[next][0] = 0;
}
int main()
{
scanf("%d", &n);
insert(INF);
insert(-INF);
while(n--)
{
scanf("%d%d", &op, &x);
switch (op)
{
case 1: insert(x); break;
case 2: remove(x); break;
case 3: find(x); printf("%d\n", siz[ch[root][0]]); break;
case 4: printf("%d\n", val[kth(x+1)]); break;
case 5: printf("%d\n", val[pre(x)]); break;
case 6: printf("%d\n", val[succ(x)]); break;
}
}
}
- insert(x):插入 x
- remove(x):删除 x(若有多个相同的,只删除一个)
- find(x):将 x 伸展到根,此时 siz[ch[root][0]] 就是 x 的排名(左子树中的 -INF 哨兵恰好提供了排名中的 +1;要求 x 在树中)
- kth(x+1):查询排名为 x 的数,val[kth(x+1)] 即为排名为 x 的数的值(+1 是为了跳过 -INF 哨兵)
- pre(x):查询 x 的前驱,前驱定义为小于 x 且最大的数
- succ(x):查询 x 的后继,后继定义为大于 x 且最小的数
这个是带reverse的模板:
#include <bits/stdc++.h>
using namespace std;
// Node-pool capacity; node id 0 is reserved as the "no node" value (ids start at 1).
const int maxn = 1e5+7;
// Splay tree stored in parallel arrays indexed by node id:
//   ch / fa - children / parent (0 if absent), val - stored value,
//   cnt     - copies of val, siz - subtree size (maintained by pushup),
//   rev     - lazy flag: the node's two children still have to be swapped (see pushdown),
//   root    - current root id, ncnt - node ids handed out so far.
int ch[maxn][2], fa[maxn], val[maxn], cnt[maxn], siz[maxn], rev[maxn], root, ncnt;
// n: sequence length; m: number of reversals; x, y: current range read by main.
int n, m, x, y;
// Tells which side x hangs on: true if x is its parent's right child, false if left.
bool chk(int x)
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
// Recomputes siz[x] as x's own multiplicity plus the sizes of both subtrees.
void pushup(int x)
{
    const int lc = ch[x][0];
    const int rc = ch[x][1];
    siz[x] = cnt[x] + siz[lc] + siz[rc];
}
// Resolves a pending reversal at x: x's two children are swapped now, and the
// flag is handed down to both of them so their own children are swapped later.
void pushdown(int x)
{
    if (!rev[x]) return;
    int &lc = ch[x][0];
    int &rc = ch[x][1];
    swap(lc, rc);
    rev[lc] ^= 1;
    rev[rc] ^= 1;
    rev[x] = 0;
}
// Rotates x one level up, above its parent y, keeping the in-order sequence.
// k records which side x hangs on; w is x's child on the opposite side, which
// moves across to become y's child on side k.
// When w == 0 or y is the root (z == 0), the writes to fa[0] / ch[0][*] only
// touch the null id 0.
// No pushdown() happens here. NOTE(review): callers are expected to have pushed
// down x and y beforehand (kth() pushes down every node on its search path).
void rotate(int x)
{
int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
ch[y][k] = w;
fa[w] = y;
// Put x where y used to hang under z; chk(y) still sees the old links here.
ch[z][chk(y)] = x;
fa[x] = z;
ch[x][k^1] = y;
fa[y] = x;
// y is now x's child, so refresh it first.
pushup(y);
pushup(x);
}
// Splays x upward until its parent is goal. goal = 0 (the default) makes x the
// root and records it in the global root.
// While the grandparent is not goal, a double rotation is done. If x and its
// parent are on the same side (zig-zig), the parent is rotated first; otherwise
// (zig-zag) x is rotated twice.
// No reversal tags are pushed down here. In this file splay() is applied to
// nodes just returned by kth() (which pushes down along its path) or from insert()
// before any reversal exists.
void splay(int x, int goal = 0)
{
while(fa[x] != goal)
{
int y = fa[x], z = fa[y];
if(z != goal)
{
if(chk(x) == chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if(!goal) root = x;
}
void insert(int x)
{
int cur = root, p = 0;
while(cur && val[cur] != x)
{
p = cur;
cur = ch[cur][x > val[cur]];
}
if(cur) cnt[cur]++;
else
{
cur = ++ncnt;
if(p) ch[p][x > val[p]] = cur;
ch[cur][0] = ch[cur][1] = 0;
fa[cur] = p;
val[cur] = x;
cnt[cur] = siz[cur] = 1;
}
splay(cur);
}
// Searches for x by value and splays the last node reached to the root.
// Only meaningful while the tree is still ordered by val (before any reverse()).
void find(int x)
{
    int cur = root;
    while (val[cur] != x)
    {
        const int step = ch[cur][x > val[cur]];
        if (!step) break;
        cur = step;
    }
    splay(cur);
}
// Returns the node at in-order position k (1-based, copies counted via cnt).
// Pending reversals are pushed down along the way so child order is accurate.
int kth(int k)
{
    int cur = root;
    for (;;)
    {
        pushdown(cur);
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        const int upto = siz[lc] + cnt[cur];
        if (k <= upto)
            return cur;
        k -= upto;
        cur = ch[cur][1];
    }
}
// Reverses the elements at positions l..r of the sequence (1-based).
void reverse(int l, int r)
{
    // In-order position p + 1 holds element p (position 1 is the 0 sentinel),
    // so these are the nodes just before l and just after r.
    const int before = kth(l);
    const int after = kth(r + 2);
    splay(before);
    splay(after, before);
    // after's left subtree now spans exactly l..r: tag it for lazy reversal.
    const int seg = ch[after][0];
    rev[seg] ^= 1;
}
// Returns the node holding the largest value strictly smaller than x.
// Value-based, so only valid before any reverse().
int pre(int x)
{
    find(x);
    if (val[root] >= x)
    {
        // The root is not smaller than x: the answer is the maximum of its left subtree.
        int cur = ch[root][0];
        while (ch[cur][1] != 0)
            cur = ch[cur][1];
        return cur;
    }
    return root;
}
// Returns the node holding the smallest value strictly greater than x.
// Value-based, so only valid before any reverse().
int succ(int x)
{
    find(x);
    if (val[root] <= x)
    {
        // The root is not greater than x: the answer is the minimum of its right subtree.
        int cur = ch[root][1];
        while (ch[cur][0] != 0)
            cur = ch[cur][0];
        return cur;
    }
    return root;
}
// In-order traversal of x's subtree that pushes reversals down as it goes.
// It prints every value except 0 and values above n, i.e. it skips the
// sentinels 0 and n+1.
void output(int x)
{
    pushdown(x);
    const int lc = ch[x][0];
    if (lc) output(lc);
    if (val[x] != 0 && val[x] <= n) printf("%d ", val[x]);
    const int rc = ch[x][1];
    if (rc) output(rc);
}
int main()
{
    scanf("%d%d", &n, &m);
    // Values 1..n form the initial sequence; 0 and n+1 are sentinels at both ends.
    for (int v = 0; v <= n + 1; ++v)
        insert(v);
    while (m--)
    {
        scanf("%d%d", &x, &y);
        reverse(x, y);
    }
    output(root);
    return 0;
}
- reverse(l, r):将区间 [l, r] 翻转(位置从 1 开始)
- output(x):从节点 x 开始中序遍历(同时下放翻转标记),输出 splay 树上除哨兵 0 和 n+1 之外的所有数
下面这个是用 splay 维护数列的模板(支持区间插入、删除、赋值、翻转、求和以及最大子段和):
#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
// Node-pool capacity; node id 0 is reserved as the null node (ids start at 1).
const int maxn = 1e6+7;
// n, m: sequence length and number of commands; arr: input buffer read by build();
// x, y, z: operands of the current command; c appears unused in this program.
int n, m, arr[maxn], c, x, y, z;
// Command name read by main().
char buf[32];
// Per-node aggregates over the node's subtree (a contiguous run of the sequence):
//   siz - node count, sum - sum of values,
//   la  - best prefix sum (empty prefix allowed, so >= 0),
//   ra  - best suffix sum (empty suffix allowed, so >= 0),
//   gss - best sum of a non-empty contiguous run,
//   upd - pending "set every value to val[x]" tag, not yet applied to the children,
//   rev - pending reversal for the children (the node itself is already swapped).
int siz[maxn], sum[maxn], upd[maxn], rev[maxn], la[maxn], ra[maxn], gss[maxn];
// val - the node's own value; ch / fa - children / parent; ncnt - fresh ids used; root - root id.
int val[maxn], ch[maxn][2], fa[maxn], ncnt, root;
// Free-list of node ids released by recycle() and reused by newNode().
queue<int> q;
// Post-order walk that returns every node of x's subtree (x included) to the
// free-list q so newNode() can reuse the ids. Recursion depth equals the
// subtree's height.
void recycle(int x)
{
    for (int side = 0; side < 2; ++side)
    {
        const int child = ch[x][side];
        if (child) recycle(child);
    }
    q.push(x);
}
// Allocates a detached node holding value x, reusing a recycled id when one is
// available. The node has no tags, siz 1, and aggregates equal to x alone
// (la/ra clamp at 0 because an empty prefix/suffix is allowed).
inline int newNode(int x)
{
    int cur;
    if (!q.empty())
    {
        cur = q.front();
        q.pop();
    }
    else
    {
        cur = ++ncnt;
    }
    fa[cur] = ch[cur][0] = ch[cur][1] = 0;
    siz[cur] = 1;
    upd[cur] = 0;
    rev[cur] = 0;
    val[cur] = x;
    sum[cur] = x;
    gss[cur] = x;
    la[cur] = ra[cur] = x > 0 ? x : 0;
    return cur;
}
// Tells which side x hangs on: true if x is its parent's right child, false if left.
inline bool chk(int x)
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
inline void pushup(int x)
{
int l = ch[x][0], r = ch[x][1];
siz[x] = siz[l] + siz[r] + 1;
sum[x] = sum[l] + sum[r] + val[x];
// 这里和线段树不同,线段树只有叶子上有权值,平衡树上所有点都有,必须+val[x]
la[x] = max(la[l], sum[l] + val[x] + la[r]);
ra[x] = max(ra[r], sum[r] + val[x] + ra[l]);
gss[x] = max(ra[l] + val[x] + la[r], max(gss[l], gss[r]));
}
// Rotates x one level up, above its parent y, keeping the in-order sequence.
// k records which side x hangs on; w is x's child on the opposite side, which
// moves across to become y's child on side k.
// When w == 0 or y is the root (z == 0), the writes to fa[0] / ch[0][*] only
// touch the null id 0.
// No pushdown() happens here. NOTE(review): callers are expected to have pushed
// down the nodes involved (kth() pushes down every node on its search path).
inline void rotate(int x)
{
int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
ch[y][k] = w;
fa[w] = y;
// Put x where y used to hang under z; chk(y) still sees the old links here.
ch[z][chk(y)] = x;
fa[x] = z;
ch[x][k^1] = y;
fa[y] = x;
// y is now x's child, so refresh it first.
pushup(y);
pushup(x);
}
// Pushes x's lazy tags down to its children. Both tags follow the convention
// that x's own fields are already up to date and only the children lag behind.
// An assignment tag takes precedence over a pending reversal: once every value
// in the subtree is equal, order no longer matters, so rev[x] is simply dropped.
inline void pushdown(int x)
{
int l = ch[x][0], r = ch[x][1];
if(upd[x])
{
upd[x] = rev[x] = 0;
if(l)
{
// Make every value in l's subtree equal to val[x] and fix its aggregates.
upd[l] = 1;
val[l] = val[x];
sum[l] = val[x] * siz[l];
la[l] = ra[l] = max(sum[l], 0);
// Best non-empty run: a single element if the value is negative, all of them otherwise.
gss[l] = val[x] < 0 ? val[x] : sum[l];
}
if(r)
{
// Same for the right child.
upd[r] = 1;
val[r] = val[x];
sum[r] = val[x] * siz[r];
la[r] = ra[r] = max(sum[r], 0);
gss[r] = val[x] < 0 ? val[x] : sum[r];
}
}
if(rev[x])
{
// Hand the reversal to each child: mark it, swap its prefix/suffix maxima and
// its own two subtrees (with a missing child this only touches the null node 0).
rev[l] ^= 1;
rev[r] ^= 1;
rev[x] = 0;
swap(la[l], ra[l]);
swap(la[r], ra[r]);
swap(ch[l][0], ch[l][1]);
swap(ch[r][0], ch[r][1]);
}
}
// Splays x upward until its parent is goal. goal = 0 (the default) makes x the
// root and records it in the global root.
// While the grandparent is not goal, a double rotation is done. If x and its
// parent are on the same side (zig-zig), the parent is rotated first; otherwise
// (zig-zag) x is rotated twice.
// No tags are pushed down here; in this file x always comes straight from kth().
inline void splay(int x, int goal = 0)
{
while(fa[x] != goal) {
int y = fa[x], z = fa[y];
if (z != goal)
{
if (chk(x) == chk(y)) rotate(y);
else rotate(x);
}
rotate(x);
}
if (!goal) root = x;
}
// Builds a balanced subtree over arr[l..r] (the middle element becomes the root)
// and returns its root id, or 0 for an empty range. The returned root's parent
// link is left as 0 for the caller to set.
int build(int l, int r, int *arr)
{
    if (l > r) return 0;
    const int mid = (l + r) >> 1;
    const int cur = newNode(arr[mid]);
    if (l < r)
    {
        const int lc = build(l, mid - 1, arr);
        ch[cur][0] = lc;
        if (lc) fa[lc] = cur;
        const int rc = build(mid + 1, r, arr);
        ch[cur][1] = rc;
        if (rc) fa[rc] = cur;
        pushup(cur);
    }
    return cur;
}
// Returns the node at in-order position k (1-based; every node counts once).
// Lazy tags are pushed down along the way so children and sizes are accurate.
inline int kth(int k)
{
    int cur = root;
    for (;;)
    {
        pushdown(cur);
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        const int upto = siz[lc] + 1;
        if (k <= upto)
            return cur;
        k -= upto;
        cur = ch[cur][1];
    }
}
// Splices the subtree rooted at y in right after the x-th element
// (x = 0 inserts at the front).
inline void insert(int x, int y)
{
    // In-order position p + 1 holds element p because of the leading sentinel,
    // so u is element x and v is element x+1.
    const int u = kth(x + 1);
    const int v = kth(x + 2);
    splay(u);
    splay(v, u);
    // u and v are adjacent with v = right child of u, so v's left slot is empty.
    fa[y] = v;
    ch[v][0] = y;
    pushup(v);
    pushup(u);
}
// Returns the sum of the y consecutive elements that start at position x.
inline int qsum(int x, int y)
{
    const int u = kth(x);
    const int v = kth(x + y + 1);
    splay(u);
    splay(v, u);
    // v's left subtree now spans exactly the requested range.
    const int seg = ch[v][0];
    return sum[seg];
}
// Maximum sum of a non-empty contiguous run of the whole sequence, read from
// the root's aggregate.
inline int qgss()
{
    const int top = root;
    return gss[top];
}
// Deletes the y consecutive elements that start at position x and returns their
// node ids to the free-list.
inline void remove(int x, int y)
{
    // In-order position p + 1 holds element p (leading sentinel), so u is
    // element x-1 and v is element x+y. After the two splays, the range is ch[v][0].
    int u = kth(x), v = kth(x+y+1);
    splay(u);
    splay(v, u);
    int w = ch[v][0];
    // y == 0 gives an empty range (w == 0). Recycling it would push the null id 0
    // onto the free-list, and newNode() could later hand out node 0.
    if (!w) return;
    recycle(w);
    ch[v][0] = 0;
    pushup(v);
    pushup(u);
}
// Reverses the y consecutive elements that start at position x.
inline void reverse(int x, int y)
{
    const int u = kth(x);
    const int v = kth(x + y + 1);
    splay(u);
    splay(v, u);
    const int w = ch[v][0];
    // A range whose values were all assigned the same value reads the same reversed.
    if (upd[w]) return;
    // Reverse w itself now (children and prefix/suffix maxima); its children
    // are reversed lazily through rev[w].
    rev[w] ^= 1;
    swap(ch[w][0], ch[w][1]);
    swap(la[w], ra[w]);
    pushup(v);
    pushup(u);
}
// Sets each of the y consecutive elements starting at position x to z. The
// range root is updated immediately; its descendants receive the value lazily
// through upd.
inline void update(int x, int y, int z)
{
    int u = kth(x), v = kth(x+y+1);
    splay(u);
    splay(v, u);
    int w = ch[v][0];
    // y == 0 gives an empty range (w == 0). Tagging it would overwrite the null
    // node's gss[0] = -INF, the neutral element pushup() relies on for missing
    // children, and corrupt later maximum-subarray answers.
    if (!w) return;
    upd[w] = 1;
    val[w] = z;
    sum[w] = siz[w] * z;
    la[w] = ra[w] = max(0, sum[w]);
    // Best non-empty run: a single element if z is negative, the whole range otherwise.
    gss[w] = z < 0 ? z : sum[w];
    pushup(v);
    pushup(u);
}
int main() {
    scanf("%d%d", &n, &m);
    // The initial sequence occupies arr[2..n+1]; slots 1 and n+2 become sentinels.
    for (int i = 2; i <= n+1; i++)
    {
        scanf("%d", arr+i);
    }
    // Node 0 is the null child: gss[0] = -INF keeps a missing child from winning
    // the max in pushup().
    gss[0] = val[0] = -INF;
    arr[1] = arr[n += 2] = -INF;   // n now also counts the two sentinels
    root = build(1, n, arr);
    while (m--)
    {
        // Bounded read: buf holds at most 31 characters plus the terminator.
        scanf("%31s", buf);
        // Dispatch on a small hash of the command's first three letters; the six
        // labels are pairwise distinct (duplicate case labels would not compile).
        switch ((buf[2] + buf[1]) ^ *buf)
        {
        case 'G'^('E'+'T'):   // GET...: sum of y elements starting at x
            scanf("%d%d", &x, &y);
            printf("%d\n", qsum(x, y));
            break;
        case 'M'^('A'+'X'):   // MAX...: maximum subarray sum of the sequence
            printf("%d\n", qgss());
            break;
        case 'R'^('E'+'V'):   // REV...: reverse y elements starting at x
            scanf("%d%d", &x, &y);
            reverse(x, y);
            break;
        case 'M'^('A'+'K'):   // MAK...: set y elements starting at x to z
            scanf("%d%d%d", &x, &y, &z);
            update(x, y, z);
            break;
        case 'D'^('E'+'L'):   // DEL...: delete y elements starting at x
            scanf("%d%d", &x, &y);
            remove(x, y);
            break;
        case 'I'^('N'+'S'):   // INS...: insert y new elements after element x
            scanf("%d%d", &x, &y);
            // build() reads only arr[1..y], and each of those slots is overwritten
            // below, so the ~1e6-int buffer is not cleared on every command.
            for (int i = 1; i <= y; i++)
            {
                scanf("%d", arr+i);
            }
            insert(x, build(1, y, arr));
            break;
        }
    }
    return 0;
}
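其中 qgss() 为求整个序列的最大子段和;qsum(x, y) 求从第 x 个数开始的 y 个数之和,update(x, y, z) 将这 y 个数都改为 z,reverse(x, y) 将这 y 个数翻转,remove(x, y) 删除这 y 个数,insert(x, t) 在第 x 个数之后插入由 build 建好的子树 t。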