今天所给的这些常数优化,可以使我原来所写的 splay 变快至少 1.5 倍。
1.在 rotate 过程中不需要下放标记。显然,在找到点之前,就一定会有自上而下的过程。在这里下放标记即可。
2.在 rotate 过程中不 update 当前节点。显然,只需要在 splay 到根之后 update 即可。
3.对于 max、swap 这类调用非常频繁的函数,尽量使用自己 define 的版本,大数据下这点小改动即可快 0.2 秒。
4.插入之后,splay 新插入的点中深度最大的那一个。这样能把插入形成的链压到原深度的四分之一以下,常数优化非常明显;同时这次 splay 也顺带更新了链上所有新节点的信息(建链时并未逐个 update)。
通过以上 4 个优化,我将 3.8 秒的大数据(sequence 维护数列)压到了 2 秒以下。
Code :
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <climits>
#include <iostream>
#include <algorithm>
using namespace std;
#define inf 1005
#define maxn 600005
// Splay-tree node. null[0] doubles as the shared "null" sentinel; null[1..maxn-1]
// is the node pool, handed out through the free-node stack ve[] whose top is *total.
struct node{
    node * pt, * ch[2];     // parent and children (ch[0] = left, ch[1] = right)
    int size, data, sum;    // subtree size, this node's own value, subtree sum
    int ls, rs, ms;         // best non-empty prefix / suffix / contiguous sum of the subtree
    bool rev, alt;          // pending "reverse" / "assign data to whole subtree" tags;
                            // pushdown() applies them to this node and forwards them
} null[maxn], * ve[maxn], ** total, * root, * head, * tail;

// Hot-path helpers. They used to be GNU statement-expression macros that
// declared the reserved identifier `__`; plain inline functions keep the same
// call syntax, are portable, and should compile to the same code under -O2.
static inline int MAX(int a, int b) { return a > b ? a : b; }
static inline void Swap(node *& p, node *& q) { node * o = p; p = q; q = o; }
static inline void swaP(int & a, int & b) { int t = a; a = b; b = t; }
// Pop a node off the free-node stack and initialise it as a leaf holding
// value `da` under parent `fa`; returns the node.
// A node handed back by DELETE may still own a whole subtree. Before its links
// are overwritten, its children are pushed back onto the stack, so deleted
// subtrees are recycled lazily, one node at a time. Children that are a zero
// pointer (a never-used pool slot) or the sentinel are skipped.
node * newnode(node * fa, int da)
{
    node * fresh = * total;
    -- total;
    for (int side = 0; side < 2; ++ side)
    {
        node * kid = fresh->ch[side];
        if (kid && kid != null)
            * (++ total) = kid;
    }
    fresh->pt = fa;
    fresh->ch[0] = fresh->ch[1] = null;
    fresh->size = 1;
    fresh->data = fresh->sum = fresh->ls = fresh->rs = fresh->ms = da;
    fresh->rev = fresh->alt = 0;
    return fresh;
}
// Apply p's pending tags to p itself and hand them on to its children; returns p.
//  - alt: the whole subtree takes the value p->data. sum/ls/rs/ms are rebuilt
//    from size and data: the best non-empty run is a single element when the
//    value is negative, otherwise the whole subtree.
//  - rev: mirror the subtree by swapping the children and the prefix/suffix bests.
// Children tagged here get their own fields fixed only when they are pushed
// down in turn. Called on the null sentinel it does nothing. When a child is
// the sentinel, its alt/rev/data fields do get written, but pushdown() and
// update() both bail out on null.
node * pushdown(node * p)
{
    if (p == null) return null;
    node * lc = p->ch[0];
    node * rc = p->ch[1];
    if (p->alt)
    {
        lc->alt = 1, lc->data = p->data;
        rc->alt = 1, rc->data = p->data;
        p->sum = p->size * p->data;
        int best = p->data < 0 ? p->data : p->sum;
        p->ls = p->rs = p->ms = best;
        p->alt = 0;
    }
    if (p->rev)
    {
        lc->rev = ! lc->rev;
        rc->rev = ! rc->rev;
        p->ch[0] = rc, p->ch[1] = lc;
        int tmp = p->ls;
        p->ls = p->rs, p->rs = tmp;
        p->rev = 0;
    }
    return p;
}
// Recompute p's size, sum and best prefix / suffix / contiguous sums from
// its children. The children are pushed down first, so their fields already
// reflect any pending tags. Returns p (or null for the sentinel).
node * update(node * p)
{
    if (p == null) return null;
    node * lc = pushdown(p->ch[0]);
    node * rc = pushdown(p->ch[1]);
    int d = p->data;

    p->size = lc->size + rc->size + 1;
    p->sum = lc->sum + d + rc->sum;

    // A prefix either stays inside the left child or runs through p
    // (optionally continuing into a prefix of the right child); the suffix
    // case is symmetric.
    int leftThrough = lc->sum + d;
    int rightThrough = rc->sum + d;
    p->ls = std::max(lc->ls, std::max(leftThrough, leftThrough + rc->ls));
    p->rs = std::max(rc->rs, std::max(rightThrough, rightThrough + lc->rs));

    // The best run lies inside one child, is p alone, or passes through p.
    int inside = std::max(d, std::max(lc->ms, rc->ms));
    int across = std::max(lc->rs + d + rc->ls, std::max(lc->rs, rc->ls) + d);
    p->ms = std::max(inside, across);
    return p;
}
// Single rotation: lift p one level, above its parent q.
// For speed it deliberately does NOT push down lazy tags; it relies on the
// path to p already being pushed down (e.g. getpos() pushes down every node
// it visits). It also does NOT update p's aggregates: only q is recomputed
// here, and splay() updates p once, after p reaches its final position.
void rotate(node * p)
{
node * q = p->pt;
// flag == 1 when p is q's left child; p's child on side `flag` is the inner subtree that changes owner.
bool flag = p == q->ch[0];
// Give p's inner subtree to q (on the side p used to occupy), then make q p's child on that side.
// NOTE(review): when that subtree is the null sentinel this writes null->pt; nothing appears to rely on it — confirm.
p->ch[flag]->pt = q, q->ch[! flag] = p->ch[flag], p->ch[flag] = q;
// Put p where q was under q's old parent (which may be the null sentinel at the root),
// recompute q now that its children are final, and make p q's parent.
q->pt->ch[q == q->pt->ch[1]] = p, p->pt = q->pt, update(q)->pt = p;
}
// Rotate p upward until its parent is `tar` (tar == null makes p the root),
// using zig / zig-zag / zig-zig steps. Then refresh p's aggregates, which
// rotate() deliberately leaves stale. Returns p.
node * splay(node * p, node * tar)
{
    for (node * q = p->pt; q != tar; q = p->pt)
    {
        node * g = q->pt;
        if (g == tar)
            rotate(p);                                  // zig: the parent is the target
        else if ((p == q->ch[0]) != (q == g->ch[0]))
            rotate(p), rotate(p);                       // zig-zag: p twice
        else
            rotate(q), rotate(p);                       // zig-zig: parent first, then p
    }
    return update(p);
}
// Walk down from root to the node whose 1-based in-order rank is `rank`
// (the head sentinel has rank 1), pushing down tags on every node visited.
// NOTE(review): assumes 1 <= rank <= root->size; this is not checked.
node * getpos(int rank)
{
    node * cur = pushdown(root);
    int idx = cur->ch[0]->size + 1;            // rank of cur in the whole tree
    while (idx != rank)
    {
        if (idx < rank)
        {
            cur = pushdown(cur->ch[1]);
            idx += cur->ch[0]->size + 1;
        }
        else
        {
            cur = pushdown(cur->ch[0]);
            idx -= cur->ch[1]->size + 1;
        }
    }
    return cur;
}
// Isolate sequence elements l..r (1-based; the head sentinel is not counted).
// The node of rank l (element l-1) is splayed to the root, and the node of
// rank r+2 (element r+1) right below it. Elements l..r then form exactly
// root->ch[1]->ch[0]. Returns the refreshed root.
node * getseg(int l, int r)
{
    node * leftBound = getpos(l);
    root = splay(leftBound, null);
    node * rightBound = getpos(r + 2);
    splay(rightBound, root);
    return update(root);
}
// Input state shared by prepare() and main().
int n, m;          // initial sequence length, number of commands
int pos, tot, c;   // arguments of the current command
int a[maxn];       // values read for the initial sequence / an INSERT, stored in a[1..]
char str[inf];     // the current command word
// Read n, m and the initial values a[1..n], then set up the tree.
// Every pool node except null[0] goes on the free-node stack. The sentinel's
// ls/rs/ms are set to -inf so empty children never win a max in update()
// (NOTE(review): presumably inf exceeds every |value| in the input; confirm).
// Finally two value-0 sentinels, head and tail, bracket the sequence.
void prepare()
{
    scanf("%d%d", & n, & m);
    for (int i = 1; i <= n; ++ i)
        scanf("%d", & a[i]);

    // Slot 0 is the sentinel itself; every other pool slot starts out free.
    ve[0] = null;
    for (int i = 1; i < maxn; ++ i)
        ve[i] = null + i;
    total = ve + maxn - 1;                      // top of the free-node stack

    null->ls = null->rs = null->ms = - inf;

    head = newnode(null, 0);
    root = head;
    root->size = 2;                             // head plus tail, attached next
    tail = newnode(root, 0);
    root->ch[1] = tail;
}
int main()
{
freopen("sequence.in", "r", stdin);
freopen("sequence.out", "w", stdout);
prepare();
pos = 0, tot = n;
root = splay(getpos(pos + 1), null);
node * p = newnode(root, a[tot]);
p->ch[1] = root->ch[1], p->ch[1]->pt = p, root->ch[1] = p;
while (-- tot) p = p->ch[0] = newnode(p, a[tot]);
root = splay(p, null);
while (scanf("%s", str), m --)
if (str[0] == 'I')
{
scanf("%d%d", & pos, & tot);
for (int i = 1; i <= tot; ++ i)
scanf("%d", & a[i]);
root = splay(getpos(pos + 1), null);
node * p = newnode(root, a[tot]);
p->ch[1] = root->ch[1], p->ch[1]->pt = p, root->ch[1] = p;
while (-- tot) p = p->ch[0] = newnode(p, a[tot]);
root = splay(p, null);
}
else if (str[0] == 'D')
{
scanf("%d%d", & pos, & tot);
node * p = getseg(pos, pos + tot - 1);
* (++ total) = p->ch[1]->ch[0], p->ch[1]->ch[0] = null;
update(root->ch[1]), update(root);
}
else if (str[0] == 'M' and str[2] == 'K')
{
scanf("%d%d%d", & pos, & tot, & c);
node * p = getseg(pos, pos + tot - 1);
p->ch[1]->ch[0]->alt = 1, p->ch[1]->ch[0]->data = c;
pushdown(root->ch[1]->ch[0]), update(root->ch[0]), update(root);
}
else if (str[0] == 'R')
{
scanf("%d%d", & pos, & tot);
node * p = getseg(pos, pos + tot - 1);
p->ch[1]->ch[0]->rev = ! p->ch[1]->ch[0]->rev;
pushdown(root->ch[1]->ch[0]), update(root->ch[0]), update(root);
}
else if (str[0] == 'G')
{
scanf("%d%d", & pos, & tot);
node * p = getseg(pos, pos + tot - 1);
printf("%d\n", p->ch[1]->ch[0]->sum);
}
else
{
root = splay(head, null), splay(tail, root);
printf("%d\n", root->ch[1]->ch[0]->ms);
}
return 0;
}