1500: [NOI2005]维修数列
Time Limit: 10 Sec Memory Limit: 64 MBSubmit: 14170 Solved: 4575
[ Submit][ Status][ Discuss]
Description
Input
输入的第1 行包含两个数N 和M(M ≤20 000),N 表示初始时数列中数的个数,M表示要进行的操作数目。
第2行包含N个数字,描述初始时的数列。
以下M行,每行一条命令,格式参见问题描述中的表格。
任何时刻数列中最多含有500 000个数,数列中任何一个数字均在[-1 000, 1 000]内。
插入的数字总数不超过4 000 000个,输入文件大小不超过20MBytes。
Output
对于输入数据中的GET-SUM和MAX-SUM操作,向输出文件依次打印结果,每个答案(数字)占一行。
Sample Input
2 -6 3 5 1 -5 -3 6 3
GET-SUM 5 4
MAX-SUM
INSERT 8 3 -5 7 2
DELETE 12 1
MAKE-SAME 3 3 2
REVERSE 3 6
GET-SUM 5 4
MAX-SUM
Sample Output
10
1
10
HINT
Source
对于一个刚学splay的蒟蒻来说有点困难啊。。。
于是改了好几天。
讲一些比较重要的东西:
每个操作需要将splay转成一个结构:把右端点右边后继ro旋转到根,再把左端点前驱lo旋转到ro的左儿子,
这样,目标区间即为lo的右儿子
为了方便判断边界的问题,需要加入两个虚拟节点,即左边界与右边界
然后后面的操作就比较好执行了
讲一下max-sum:
需要维护lsum,rsum,maxs,分别表示该区间内从左开始向右数的最大连续和,从右往左数的最大连续和以及该区间内的最大子段和
lsum的维护要分成三种情况,一种取值为lsum[lc],一种取值为sum[lc]+v[o],另一种取值为sum[lc]+v[当前节点]+lsum[rc],三种情况取一个max值,rsum与lsum同理。
而maxs大体上分两种讨论:
一种为最大子段和跨越了当前节点o,一种没有;
后者相对容易,只需在maxs[lc],maxs[rc]取一个max值就好
而前者就为lsum[rc]+v[当前节点]+rsum[lc]
两种情况取一个max
几个蒟蒻的WA点:
1、点0的最大子段和要初始化为-INF
2、注意到当覆盖时,若覆盖的权值为负数,则lsum与rum需设置为0(而不是setv[o]),而maxs需为setv[o]
人生中第二个splay,收获颇丰
代码:
#include<cstring>
#include<iostream>
#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<stack>
using namespace std;
const int maxn = 500100;
const int INF = 2E9;
int q[500000];
char s[40];
bool flag[maxn],setf[maxn];
int a[maxn];
int n,m,cur,rt,st,ed,tot = 0,fa[maxn],ch[maxn][2],v[maxn],top;
int setv[maxn],lsum[maxn],rsum[maxn],maxs[maxn],sum[maxn],siz[maxn]; //附加信息
inline void swap(int& x,int& y)
{
int tmp = x;
x = y;
y = tmp;
}
inline void maintain(int o)
{
siz[o] = siz[ch[o][1]] + siz[ch[o][0]] + 1;
sum[o] = sum[ch[o][1]] + sum[ch[o][0]] + v[o];
lsum[o] = max(lsum[ch[o][0]],sum[ch[o][0]] + v[o] + lsum[ch[o][1]]);
rsum[o] = max(rsum[ch[o][1]],sum[ch[o][1]] + v[o] + rsum[ch[o][0]]);
maxs[o] = max(max(maxs[ch[o][1]],maxs[ch[o][0]]),lsum[ch[o][1]] + v[o] + rsum[ch[o][0]]);
}
inline void cover(int o,int x)
{
if (o == st || o == ed) return;
setf[o] = 1;
setv[o] = v[o] = x;
sum[o] = siz[o] * setv[o];
lsum[o] = rsum[o] = max(0,sum[o]);
maxs[o] = max(setv[o],sum[o]);
}
inline void reverse(int o)
{
flag[o] ^= 1;
swap(ch[o][1],ch[o][0]);
swap(lsum[o],rsum[o]);
}
inline void pushdown(int o)
{
if (flag[o])
{
flag[o] = 0;
if (ch[o][0]) reverse(ch[o][0]);
if (ch[o][1]) reverse(ch[o][1]);
}
if (setf[o])
{
setf[o] = 0;
if (ch[o][0]) cover(ch[o][0],setv[o]);
if (ch[o][1]) cover(ch[o][1],setv[o]);
setv[o] = 0;
}
}
inline void output(int o)
{
if (!o) return;
pushdown(o);
output(ch[o][0]);
// printf("%d ",v[o]);
output(ch[o][1]);
}
inline int check(int o)
{
if (!o) return 0;
if (!v[o]) return o;
return check(ch[o][0]);
return check(ch[o][1]);
}
inline void rotate(int x)
{
int o = fa[x];
int y = fa[o];
int d = (ch[o][1] == x ? 0 : 1);
ch[o][d ^ 1] = ch[x][d]; maintain(o);
if (ch[x][d]) fa[ch[x][d]] = o;
ch[x][d] = o; maintain(x);
fa[o] = x; fa[x] = y;
if (y) {ch[y][ch[y][1] == o] = x; maintain(y);}
}
inline int kth(int o,int k)
{
if (!o) return 0;
pushdown(o);
if (k == siz[ch[o][0]] + 1) return o;
else if (k <= siz[ch[o][0]]) return kth(ch[o][0],k);
else return kth(ch[o][1],k - siz[ch[o][0]] - 1);
}
inline void pushdown_root(int x)
{
if (fa[x]) pushdown_root(fa[x]);
pushdown(x);
}
inline void splay(int x)
{
if (!x) return;
pushdown_root(x);
for (int y = fa[x]; y; rotate(x) , y = fa[x])
if (fa[y]) rotate((ch[y][1] == x) ^ (ch[fa[y]][1] == y) ? x : y);
rt = x;
maintain(rt);
}
inline void rever(int l,int r)
{
int lo = kth(rt,l),ro = kth(rt,r + 2);
splay(ro);
pushdown(ro);
fa[ch[ro][0]] = 0;
splay(lo);
pushdown(lo);
rt = ro;
fa[lo] = ro;
ch[ro][0] = lo;
maintain(ro);
reverse(ch[lo][1]);
}
inline int qsum(int l,int r)
{
int lo = kth(rt,l),ro = kth(rt,r + 2);
splay(ro);
fa[ch[ro][0]] = 0;
splay(lo);
rt = ro;
fa[lo] = ro;
ch[ro][0] = lo;
maintain(ro);
return sum[ch[lo][1]];
}
inline int maxsum()
{
splay(ed);
fa[ch[ed][0]] = 0;
splay(st);
rt = ed;
fa[st] = ed;
ch[ed][0] = st;
maintain(ed);
return maxs[ch[st][1]];
}
inline int build(int l,int r,int f)
{
if (r < l) return 0;
int mid = l + r >> 1;
int x;
if (!top)
x = ++tot;
else
{
x = q[top];
q[top--] = 0;
}
v[x] = a[mid]; fa[x] = f;
ch[x][0] = build(l,mid - 1,x);
ch[x][1] = build(mid + 1,r,x);
maintain(x);
return x;
}
inline void insert(int pos,int toti)
{
int r = build(1,toti,0),lo = kth(rt,pos + 1),ro = kth(rt,pos + 2);
splay(ro);
pushdown(lo);
fa[ch[ro][0]] = 0;
splay(lo);
rt = ro;
ch[ro][0] = lo;
fa[lo] = ro;
pushdown(lo);
ch[lo][1] = r;
fa[r] = lo;
maintain(lo);
maintain(ro);
}
inline void del(int o)
{
if (!o) return;
q[++top] = o;
setv[o] = setf[o] = flag[o] = 0;
v[o] = lsum[o] = sum[o] = rsum[o] = maxs[o] = siz[o] = 0;
del(ch[o][1]);
del(ch[o][0]);
ch[o][1] = ch[o][0] = fa[o] = 0;
}
inline void remove(int l,int r)
{
int lo = kth(rt,l),ro = kth(rt,r + 2);
splay(ro);
pushdown(ro);
fa[ch[ro][0]] = 0;
splay(lo);
pushdown(lo);
rt = ro;
fa[lo] = ro;
ch[ro][0] = lo;
del(ch[lo][1]);
ch[lo][1] = 0;
maintain(lo);
maintain(ro);
}
inline void update(int l,int r,int c)
{
int lo = kth(rt,l),ro = kth(rt,r + 2);
splay(ro);
pushdown(ro);
fa[ch[ro][0]] = 0;
splay(lo);
pushdown(lo);
rt = ro;
fa[lo] = ro;
ch[ro][0] = lo;
maintain(ro);
cover(ch[lo][1],c);
}
inline int getint()
{
char c = getchar();
int ret = 0,f = 1;
while (c < '0' || c > '9')
{
if (c == '-') f = -1;
c = getchar();
}
while (c >= '0' && c <= '9')
ret = ret * 10 + c - '0',c = getchar();
return ret * f;
}
int main()
{
n = getint(); m = getint();
for (int i = 1; i <= n; i++)
a[i] = getint();
st = ++tot;
siz[st] = 1;
rt = build(1,n + 1,0);
int first = kth(rt,1);
splay(first);
ch[first][0] = 1;
fa[1] = first;
maintain(rt);
ed = tot;
maxs[0] = -INF;
for (int j = 1; j <= m; j++)
{
scanf("%s",s);
if (j == 228)
int k = 200;
if (!strcmp(s,"INSERT"))
{
int pos = getint(),x = getint();
for (int i = 1; i <= x; i++)
a[i] = getint();
insert(pos,x);
}
if (!strcmp(s,"DELETE"))
{
int pos = getint(),x = getint();
remove(pos,pos + x - 1);
}
if (!strcmp(s,"MAKE-SAME"))
{
int pos = getint(),x = getint(),c = getint();
update(pos,pos + x - 1,c);
}
if (!strcmp(s,"REVERSE"))
{
int pos = getint(),x = getint();
rever(pos,pos + x - 1);
}
if (!strcmp(s,"GET-SUM"))
{
int pos = getint(),x = getint();
printf("%d\n",qsum(pos,pos + x - 1));
}
if (!strcmp(s,"MAX-SUM"))
printf("%d\n",maxsum());
}
return 0;
}