splay模板

代码来自洛谷日报,感谢洛谷
本篇文章没有splay的教学,只有代码的使用说明。

直接贴上代码:
下面是不带 reverse(区间翻转)操作的模板:

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
const int maxn = 2e5+7;
int n, op, x;
int ch[maxn][2], fa[maxn], val[maxn], cnt[maxn], siz[maxn], ncnt, root;
// Which side of its parent x hangs on: true (1) = right child, false (0) = left child.
bool chk(int x) 
{
    const int parent = fa[x];
    return x == ch[parent][1];
}
// Refresh siz[x]: the number of stored values (counting multiplicity) in x's subtree.
void pushup(int x) 
{
    const int lc = ch[x][0];
    const int rc = ch[x][1];
    siz[x] = siz[lc] + siz[rc] + cnt[x];
}
// Single rotation: lift x above its parent y while preserving in-order order.
//   k - side of y that x is on (see chk)
//   w - x's child on the side opposite to k; it moves across to become y's
//       child on side k
// When y is the root, z == 0, so the ch[z][...] write lands on dummy node 0;
// when w == 0, fa[w] = y likewise writes fa[0]. siz[0] is never written here.
void rotate(int x) 
{
    int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
    ch[y][k] = w; 
	fa[w] = y;
    // chk(y) still sees fa[y] == z at this point, so x takes y's old slot under z.
    ch[z][chk(y)] = x;
	fa[x] = z;
    ch[x][k^1] = y; 
	fa[y] = x;
    // y is now x's child, so it must be recomputed before x.
    pushup(y); 
	pushup(x);
}
// Splay x upward until fa[x] == goal; goal == 0 makes x the new root.
// If x and its parent sit on the same side, the parent is rotated first (zig-zig);
// otherwise x is rotated twice (zig-zag).
void splay(int x, int goal = 0) 
{
    for (int y = fa[x]; y != goal; y = fa[x])
    {
        if (fa[y] != goal)
            rotate(chk(x) == chk(y) ? y : x);
        rotate(x);
    }
    if (goal == 0) root = x;
}
void insert(int x) 
{
    int cur = root, p = 0;
    while(cur && val[cur] != x) 
	{
        p = cur;
        cur = ch[cur][x > val[cur]];
    }
    if(cur) cnt[cur]++;
	else 
	{
        cur = ++ncnt;
        if(p) ch[p][x > val[p]] = cur;
        ch[cur][0] = ch[cur][1] = 0;
        fa[cur] = p; 
		val[cur] = x;
        cnt[cur] = siz[cur] = 1;
    }
    splay(cur);
}
// Walk from the root toward value x and splay the last node reached to the root.
// Afterwards the root holds x if it is stored; otherwise it holds the node where the
// search ran out of children (a neighbour of x in sorted order).
// Assumes the tree is non-empty.
void find(int x) 
{
    int cur = root;
    for (;;)
    {
        if (val[cur] == x) break;
        const int step = ch[cur][x > val[cur]];
        if (!step) break;
        cur = step;
    }
    splay(cur);
}
// Return the node holding the k-th smallest stored value (1-based, counting
// multiplicities). Assumes 1 <= k <= siz[root].
int kth(int k) 
{
    int cur = root;
    for (;;)
    {
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        const int upToHere = siz[lc] + cnt[cur];
        if (k <= upToHere) return cur;
        k -= upToHere;
        cur = ch[cur][1];
    }
}
// Return the node holding the largest value strictly less than x.
// Relies on a smaller value existing (main inserts -INF as a sentinel).
int pre(int x) 
{
    find(x);
    if (val[root] < x) return root;
    // The root is >= x: the predecessor is the maximum of its left subtree.
    int cur = ch[root][0];
    for (; ch[cur][1]; cur = ch[cur][1]) {}
    return cur;
}

// Return the node holding the smallest value strictly greater than x.
// Relies on a larger value existing (main inserts INF as a sentinel).
int succ(int x) 
{
    find(x);
    if (val[root] > x) return root;
    // The root is <= x: the successor is the minimum of its right subtree.
    int cur = ch[root][1];
    for (; ch[cur][0]; cur = ch[cur][0]) {}
    return cur;
}

void remove(int x) 
{
    int last = pre(x), next = succ(x);
    splay(last); 
	splay(next, last);
    int del = ch[next][0];
    if(cnt[del] > 1) 
	{
        cnt[del]--;
        splay(del);
    }
    else ch[next][0] = 0;
}
int main() 
{
    scanf("%d", &n);
    insert(INF);
    insert(-INF);
    while(n--) 
	{
        scanf("%d%d", &op, &x);
        switch (op) 
		{
            case 1: insert(x); break;
            case 2: remove(x); break;
            case 3: find(x); printf("%d\n", siz[ch[root][0]]); break;
            case 4: printf("%d\n", val[kth(x+1)]); break;
            case 5: printf("%d\n", val[pre(x)]); break;
            case 6: printf("%d\n", val[succ(x)]); break;
        }
    }
}
  • insert(x),插入x

  • remove(x),删除x(若有多个相同的,只删除一个)

  • find(x),将x伸展到根,此时siz[ch[root][0]]就是x左子树的大小;由于树中预先插入了哨兵-INF,这个值恰好等于比x小的数的个数+1,即x的排名(前提是x在树中)

  • kth(k)返回排名为k的结点;由于最小的位置被哨兵-INF占用,查询排名为x的数时要调用kth(x+1),val[kth(x+1)]即为排名为x的数的值

  • pre(x),查询x的前驱,前驱定义为小于x且最大的数

  • succ(x),查询x的后继,后继定义为大于x且最小的数

这个是带reverse的模板:

#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e5+7;
int ch[maxn][2], fa[maxn], val[maxn], cnt[maxn], siz[maxn], rev[maxn], root, ncnt;
int n, m, x, y;
// Side of x under its parent: 1 if x is the right child, 0 if it is the left child.
bool chk(int x) 
{
    const int* kids = ch[fa[x]];
    return kids[1] == x;
}
// Refresh siz[x] from both children's sizes plus x's own multiplicity.
void pushup(int x) 
{
    const int lc = ch[x][0];
    const int rc = ch[x][1];
    siz[x] = siz[lc] + siz[rc] + cnt[x];
}
// Resolve a pending reversal at x: swap x's children and hand the flag on to each
// of them. A set rev[x] therefore means "x's children are still in un-flipped
// order" (reverse() only toggles the flag without swapping).
void pushdown(int x) 
{
    if (!rev[x]) return;
    int &lc = ch[x][0];
    int &rc = ch[x][1];
    swap(lc, rc);
    rev[lc] ^= 1;
    rev[rc] ^= 1;
    rev[x] = 0;
}
// Single rotation: lift x above its parent y while preserving in-order order.
//   k - side of y that x is on (see chk)
//   w - x's child opposite to k; it is re-attached to y on side k
// No pushdown happens here: it relies on the nodes involved having been pushed
// down already (in this program every splay follows kth(), which does that).
// When y is the root, z == 0 and the ch[z][...] write touches dummy node 0.
void rotate(int x) 
{
    int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
    ch[y][k] = w; 
	fa[w] = y;
    // chk(y) still sees fa[y] == z here, so x takes y's old slot under z.
    ch[z][chk(y)] = x; 
	fa[x] = z;
    ch[x][k^1] = y; 
	fa[y] = x;
    // y is now x's child, so recompute it first.
    pushup(y); 
	pushup(x); 
}
// Splay x upward until fa[x] == goal (goal == 0 makes x the root).
// Same-side parent and grandparent: rotate the parent first (zig-zig);
// otherwise rotate x twice (zig-zag).
void splay(int x, int goal = 0) 
{
    for (int y = fa[x]; y != goal; y = fa[x])
    {
        if (fa[y] != goal)
            rotate(chk(x) == chk(y) ? y : x);
        rotate(x);
    }
    if (goal == 0) root = x;
}
void insert(int x) 
{
    int cur = root, p = 0;
    while(cur && val[cur] != x) 
	{
        p = cur;
        cur = ch[cur][x > val[cur]];
    }
    if(cur) cnt[cur]++;
	else 
	{
        cur = ++ncnt;
        if(p) ch[p][x > val[p]] = cur;
        ch[cur][0] = ch[cur][1] = 0;
        fa[cur] = p; 
		val[cur] = x;
        cnt[cur] = siz[cur] = 1;
    }
    splay(cur);
}
// Walk toward value x and splay the last node reached to the root.
// Does not push down reversal tags, so it is only meaningful while the tree is
// still ordered by value (i.e. before any reverse()).
void find(int x) 
{
    int cur = root;
    for (;;)
    {
        if (val[cur] == x) break;
        const int step = ch[cur][x > val[cur]];
        if (!step) break;
        cur = step;
    }
    splay(cur);
}
// Return the node at in-order position k (1-based, counting cnt), pushing reversal
// tags down on the way so every child pointer read is final.
// Assumes 1 <= k <= siz[root].
int kth(int k) 
{
    int cur = root;
    for (;;)
    {
        pushdown(cur);
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        const int upToHere = siz[lc] + cnt[cur];
        if (k <= upToHere) return cur;
        k -= upToHere;
        cur = ch[cur][1];
    }
}
// Reverse the real elements at positions l..r (1-based, 1 <= l <= r <= n).
// The sentinel value 0 occupies in-order position 1, so kth(l) is the node just
// before the range and kth(r+2) the node just after it. Both are located before
// any splay so that the whole path is already pushed down. After splaying, the
// range is exactly ch[after][0], which only gets a lazy flag here.
void reverse(int l, int r) 
{
    const int before = kth(l);
    const int after = kth(r + 2);
    splay(before);
    splay(after, before);
    const int seg = ch[after][0];
    rev[seg] ^= 1;
}
// Node with the largest value strictly less than x (by value order; ignores
// reversal tags, so only meaningful before any reverse()).
int pre(int x) 
{
    find(x);
    if (val[root] < x) return root;
    int cur = ch[root][0];
    for (; ch[cur][1]; cur = ch[cur][1]) {}
    return cur;
}
// Node with the smallest value strictly greater than x (by value order; ignores
// reversal tags, so only meaningful before any reverse()).
int succ(int x) 
{
    find(x);
    if (val[root] > x) return root;
    int cur = ch[root][1];
    for (; ch[cur][0]; cur = ch[cur][0]) {}
    return cur;
}
// In-order traversal of the subtree rooted at x. Reversal tags are pushed down
// on the way, and every value in 1..n is printed (the sentinels 0 and n+1 are
// skipped). Uses an explicit stack instead of recursion: the tree can be very
// deep (inserting keys in ascending order, as main does, yields a long chain),
// and recursion depth would equal the tree height.
void output(int x) 
{
    std::vector<int> stk;
    int cur = x;
    while (cur || !stk.empty())
    {
        // Descend along left children, resolving each node's pending flip first
        // so that ch[cur][0] / ch[cur][1] are its true children.
        for (; cur; cur = ch[cur][0])
        {
            pushdown(cur);
            stk.push_back(cur);
        }
        cur = stk.back();
        stk.pop_back();
        if (val[cur] && val[cur] <= n) printf("%d ", val[cur]);
        cur = ch[cur][1];
    }
}
int main()
{
    scanf("%d%d", &n, &m);
    // Values 0 and n+1 are sentinels around the real sequence 1..n; output()
    // skips them and reverse() uses them as range boundaries.
    for (int v = 0; v <= n + 1; ++v)
        insert(v);
    while (m--)
    {
        scanf("%d%d", &x, &y);
        reverse(x, y);
    }
    output(root);
    return 0;
}
  • reverse(l,r),区间[l,r]翻转
  • output(x),从结点x开始中序遍历(沿途下放翻转标记),输出1~n之间的所有数(哨兵0和n+1不输出)

下面这个是用splay维护数列的模板(支持在任意位置插入一段数、删除一段、区间推平、区间翻转、区间求和以及全局最大子段和):

#include <bits/stdc++.h>
#define INF 0x3f3f3f3f
using namespace std;
const int maxn = 1e6+7;
int n, m, arr[maxn], c, x, y, z;
char buf[32];
int siz[maxn], sum[maxn], upd[maxn], rev[maxn], la[maxn], ra[maxn], gss[maxn];
int val[maxn], ch[maxn][2], fa[maxn], ncnt, root;
queue<int> q;
// Return every node of the subtree rooted at x to the free pool q, so newNode()
// can reuse the indices. x == 0 (an empty subtree) is ignored: index 0 is the
// shared null node and must never be handed out by newNode().
// Uses an explicit stack because a removed subtree may be deep after splaying.
void recycle(int x) 
{
    if (!x) return;
    static std::vector<int> stk;
    stk.clear();
    stk.push_back(x);
    while (!stk.empty())
    {
        const int cur = stk.back();
        stk.pop_back();
        q.push(cur);
        if (ch[cur][0]) stk.push_back(ch[cur][0]);
        if (ch[cur][1]) stk.push_back(ch[cur][1]);
    }
}
// Allocate a node holding value x, preferring an index from the recycle pool,
// and initialise it as a one-element subtree with no pending tags.
inline int newNode(int x) 
{
    int cur;
    if (!q.empty())
    {
        cur = q.front();
        q.pop();
    }
    else
    {
        cur = ++ncnt;
    }
    ch[cur][0] = 0;
    ch[cur][1] = 0;
    fa[cur] = 0;
    val[cur] = x;
    sum[cur] = x;
    gss[cur] = x;
    const int bestEnd = max(0, x);   // prefix/suffix bests may be empty
    la[cur] = bestEnd;
    ra[cur] = bestEnd;
    upd[cur] = 0;
    rev[cur] = 0;
    siz[cur] = 1;
    return cur;
}
// True when x is its parent's right child, false when it is the left child.
inline bool chk(int x) 
{
    const int parent = fa[x];
    return ch[parent][1] == x;
}
inline void pushup(int x) 
{
    int l = ch[x][0], r = ch[x][1];
    siz[x] = siz[l] + siz[r] + 1;
    sum[x] = sum[l] + sum[r] + val[x];
    // 这里和线段树不同,线段树只有叶子上有权值,平衡树上所有点都有,必须+val[x] 
    la[x] = max(la[l], sum[l] + val[x] + la[r]);
    ra[x] = max(ra[r], sum[r] + val[x] + ra[l]);
    gss[x] = max(ra[l] + val[x] + la[r], max(gss[l], gss[r]));
}

// Single rotation: lift x above its parent y while preserving in-order order.
//   k - side of y that x is on (see chk)
//   w - x's child opposite to k; it is re-attached to y on side k
// No pushdown here: every splay in this program follows kth(), which has already
// pushed down the tags on the path being rotated.
// When y is the root, z == 0 and the ch[z][...] write touches dummy node 0.
inline void rotate(int x) 
{
    int y = fa[x], z = fa[y], k = chk(x), w = ch[x][k^1];
    ch[y][k] = w; 
	fa[w] = y;
    // chk(y) still sees fa[y] == z here, so x takes y's old slot under z.
    ch[z][chk(y)] = x; 
	fa[x] = z;
    ch[x][k^1] = y; 
	fa[y] = x;
    // y is now x's child, so its aggregates are rebuilt before x's.
    pushup(y); 
	pushup(x);
}
// Push the pending tags at x down to its children.
// Tag convention (matching update() and reverse()): x's own aggregates are already
// correct, and a set tag means its children still have to receive the change.
//   upd[x] - every element in x's subtree equals val[x]
//   rev[x] - x's children and la/ra were already swapped; the children's own
//            subtrees still have to be flipped
inline void pushdown(int x) 
{
    int l = ch[x][0], r = ch[x][1];
    if(upd[x]) 
	{
        // A uniform subtree reads the same when reversed, so a pending
        // reversal is dropped together with the assignment.
        upd[x] = rev[x] = 0;
        if(l) 
		{
            upd[l] = 1; 
			val[l] = val[x];
            sum[l] = val[x] * siz[l];
            la[l] = ra[l] = max(sum[l], 0);
            // Best non-empty run: a single element if the value is negative, else all of them.
            gss[l] = val[x] < 0 ? val[x] : sum[l];
        }
        if(r) 
		{
            upd[r] = 1; 
			val[r] = val[x];
            sum[r] = val[x] * siz[r];
            la[r] = ra[r] = max(sum[r], 0);
            gss[r] = val[x] < 0 ? val[x] : sum[r];
        }
    }
    if(rev[x]) 
	{
        // Flip each child: toggle its tag, swap its prefix/suffix bests and its
        // own children. l or r may be 0, in which case node 0's fields are touched.
        rev[l] ^= 1; 
		rev[r] ^= 1; 
		rev[x] = 0;
        swap(la[l], ra[l]); 
		swap(la[r], ra[r]);
        swap(ch[l][0], ch[l][1]);
        swap(ch[r][0], ch[r][1]);
    }
}
// Splay x upward until fa[x] == goal (goal == 0 makes x the root): zig-zig when
// x and its parent are on the same side, zig-zag otherwise.
inline void splay(int x, int goal = 0) 
{
    for (int y = fa[x]; y != goal; y = fa[x])
    {
        if (fa[y] != goal)
            rotate(chk(x) == chk(y) ? y : x);
        rotate(x);
    }
    if (goal == 0) root = x;
}
// Build a height-balanced subtree over arr[l..r] (the midpoint becomes the root)
// and return its root index, or 0 for an empty range. The node for the midpoint
// is allocated before either half.
int build(int l, int r, int *arr) 
{
    if (l > r) return 0;
    const int mid = (l + r) >> 1;
    const int cur = newNode(arr[mid]);
    if (l == r) return cur;
    const int lc = build(l, mid - 1, arr);
    ch[cur][0] = lc;
    if (lc) fa[lc] = cur;
    const int rc = build(mid + 1, r, arr);
    ch[cur][1] = rc;
    if (rc) fa[rc] = cur;
    pushup(cur);
    return cur;
}
// Return the node at in-order position k (1-based; position 1 is the left
// sentinel), pushing pending tags down along the path.
inline int kth(int k) 
{
    int cur = root;
    for (;;)
    {
        pushdown(cur);
        const int lc = ch[cur][0];
        if (lc && k <= siz[lc])
        {
            cur = lc;
            continue;
        }
        const int upToHere = siz[lc] + 1;
        if (k <= upToHere) return cur;
        k -= upToHere;
        cur = ch[cur][1];
    }
}
// Splice the pre-built subtree rooted at y in right after the x-th element
// (x == 0 inserts at the front). Because of the left sentinel, kth(x+1) is that
// element and kth(x+2) the one following it; after splaying, the gap between
// them is the empty left slot of the second node.
inline void insert(int x, int y) 
{
    const int lo = kth(x + 1);
    const int hi = kth(x + 2);
    splay(lo);
    splay(hi, lo);
    ch[hi][0] = y;
    fa[y] = hi;
    pushup(hi);
    pushup(lo);
}
// Sum of the y elements starting at position x. kth(x) is the node just before
// the range and kth(x+y+1) the node just after it (the left sentinel shifts
// positions by one); after the two splays the range is exactly ch[after][0].
inline int qsum(int x, int y) 
{
    const int before = kth(x);
    const int after = kth(x + y + 1);
    splay(before);
    splay(after, before);
    const int seg = ch[after][0];
    return sum[seg];
}

// Best non-empty contiguous sum over the whole sequence, read from the root's aggregate.
inline int qgss() 
{
    const int best = gss[root];
    return best;
}
// DELETE: remove the y elements starting at position x and recycle their nodes.
// The range is isolated as ch[v][0] between its two neighbours (see qsum).
inline void remove(int x, int y) 
{
    int u = kth(x), v = kth(x+y+1);
    splay(u); 
    splay(v, u);
    int w = ch[v][0];
    // For an empty range w is the null node 0, which must never enter the
    // free pool (newNode would hand it out as a real node).
    if (w) recycle(w);
    ch[v][0] = 0;
    pushup(v); 
    pushup(u);
}

// REVERSE: flip the y elements starting at position x. The isolated segment w has
// its children and prefix/suffix bests swapped immediately and its rev flag
// toggled for the levels below it.
inline void reverse(int x, int y) 
{
    const int u = kth(x);
    const int v = kth(x + y + 1);
    splay(u);
    splay(v, u);
    const int w = ch[v][0];
    // A segment with a pending MAKE-SAME is uniform and reads the same reversed.
    if (upd[w]) return;
    rev[w] ^= 1;
    swap(ch[w][0], ch[w][1]);
    swap(la[w], ra[w]);
    pushup(v);
    pushup(u);
}
// MAKE-SAME: set the y elements starting at position x to the value z.
// The isolated segment w gets its own aggregates set directly; upd[w] marks that
// its children still have to receive the value.
inline void update(int x, int y, int z) 
{
    int u = kth(x), v = kth(x+y+1);
    splay(u); 
    splay(v, u);
    int w = ch[v][0];
    // Empty range: w is the null node 0. Writing to it would overwrite
    // gss[0] = -INF and corrupt every later max-subarray computation.
    if (!w) return;
    upd[w] = 1; 
    val[w] = z; 
    sum[w] = siz[w] * z;
    la[w] = ra[w] = max(0, sum[w]);
    // Best non-empty run: one element if z is negative, else the whole segment.
    gss[w] = z < 0 ? z : sum[w];
    pushup(v); 
    pushup(u);
}
int main() {
    scanf("%d%d", &n, &m);
    // The real sequence occupies arr[2..n+1]; arr[1] and arr[n+2] become sentinels.
    for (int i = 2; i <= n+1; i++)
    {
        scanf("%d", arr+i);
    }
    // Node 0 stands for an empty subtree; gss[0] = -INF keeps it from ever
    // winning the max() in pushup.
    gss[0] = val[0] = -INF;
    arr[1] = arr[n += 2] = -INF;
    // Use the index build() returns rather than assuming the root is node 1.
    root = build(1, n, arr);
    while (m--) 
    {
        // The width limit keeps an over-long token from overflowing buf.
        if (scanf("%31s", buf) != 1) break;
        if (strcmp(buf, "GET-SUM") == 0)
        {
            scanf("%d%d", &x, &y);
            printf("%d\n", qsum(x, y));
        }
        else if (strcmp(buf, "MAX-SUM") == 0)
        {
            printf("%d\n", qgss());
        }
        else if (strcmp(buf, "REVERSE") == 0)
        {
            scanf("%d%d", &x, &y);
            reverse(x, y);
        }
        else if (strcmp(buf, "MAKE-SAME") == 0)
        {
            scanf("%d%d%d", &x, &y, &z);
            update(x, y, z);
        }
        else if (strcmp(buf, "DELETE") == 0)
        {
            scanf("%d%d", &x, &y);
            remove(x, y);
        }
        else if (strcmp(buf, "INSERT") == 0)
        {
            scanf("%d%d", &x, &y);
            // build() only reads arr[1..y], so the array needs no clearing
            // (a memset here would cost O(maxn) on every INSERT).
            for (int i = 1; i <= y; i++) 
            {
                scanf("%d", arr+i);
            }
            insert(x, build(1, y, arr));
        }
    }
    return 0;
}

其中qgss()返回整个序列的最大子段和(子段至少包含一个数),qsum(x, y)返回从第x个数开始的连续y个数之和。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值