测试地址:Sequence
**做法:**本题需要用到贪心+左偏树。
在讲做法之前先说几句无关的话…NOI考完之后内心一片空虚,于是在颓废了约十天之后终于鼓起勇气写代码了,实在是可喜可贺…
对于这一道题,题目要求的是递增序列,发现不太好求,于是根据要求的值的几何意义,如果我们把原数列做
a
i
=
a
i
−
i
a_i=a_i-i
ai=ai−i这样的变换,那么我们求出的不递减序列的最优解和原来的最优解相同。
序列不递减意味着可以相同。那么显然地,如果
a
a
a是递减序列,根据问题的几何意义,我们知道要求的序列
z
z
z中的每一项都是
a
a
a的中位数
w
w
w时是最优的;如果
a
a
a是不递减序列,那么要求的序列
z
z
z等于
a
a
a时是最优的。不难看出,我们可以把这种情况看成是,将序列的每一项分为一段,每一段的答案是该段内的中位数。
于是现在我们猜想:对于任意一个序列,最优答案是不是将序列分成若干段,其中每一段的答案都是该段内的中位数呢?
假设我们已经得到了序列
a
a
a和
b
b
b的最优解,它们的最优解都是序列的中位数。现在我们只要证明出将
a
,
b
a,b
a,b拼起来的序列可以被表示成几段中位数拼起来的形式,就可以归纳出上面的结论。令
a
,
b
a,b
a,b的中位数分别为
u
,
v
u,v
u,v,当
u
≤
v
u\le v
u≤v时,最优解完全不用变化。当
u
>
v
u>v
u>v时,根据一些证明(可以看论文),就可以知道
a
,
b
a,b
a,b拼起来后,最优解中的每一项都是拼起来后的序列的中位数。
因此我们得到了一种基于合并的算法:假设我们已经求出了
a
a
a的前
i
i
i项的答案,即分好了段,那么计算第
i
+
1
i+1
i+1项,首先把这一项单独划分为一段,然后对比当前段与前面的段的中位数,如果大于等于之前的中位数就不用管,否则就要把当前段与前面的段合并成一段,然后继续比较。容易发现可以用栈实现这个比较的过程,那么我们怎么快速的支持查询中位数以及合并呢?其实,因为没有删除元素的操作,我们只需要保存每一段最小的
⌈
n
2
⌉
\lceil\frac{n}{2}\rceil
⌈2n⌉个元素,其中的最大值就是我们所求的中位数了。这显然可以用堆维护。但我们还需要支持快速合并,因此就用可并堆,其中最好写的方法是左偏树。在合并之后堆中的元素数量可能超过
⌈
n
2
⌉
\lceil\frac{n}{2}\rceil
⌈2n⌉(不可能小于,仔细想一想就知道为什么了),那么就把顶上的元素弹出直到堆中的元素数量合法即可。
左偏树一次合并和弹出的时间复杂度是
O
(
log
n
)
O(\log n)
O(logn),而每个元素至多被弹出一次,合并也最多会执行
n
n
n次,所以总的时间复杂度是
O
(
n
log
n
)
O(n\log n)
O(nlogn),可以通过此题。
以下是本人代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
int n,st[1000010],top,l[1000010],r[1000010];
int siz[1000010],dis[1000010]={0},ch[1000010][2]={0};
ll a[1000010];
int merge(int x,int y)
{
if (!x) return y;
if (!y) return x;
if (a[x]<a[y]) swap(x,y);
ch[x][1]=merge(ch[x][1],y);
if (dis[ch[x][0]]<dis[ch[x][1]])
swap(ch[x][0],ch[x][1]);
dis[x]=dis[ch[x][1]]+1;
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
return x;
}
int Delete(int x)
{
return merge(ch[x][0],ch[x][1]);
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
a[i]-=i;
}
top=0;
dis[0]=-1;
for(int i=1;i<=n;i++)
{
siz[i]=1;dis[i]=0;
top++;
st[top]=l[top]=r[top]=i;
while(top>1&&a[st[top-1]]>a[st[top]])
{
st[top-1]=merge(st[top-1],st[top]);
r[top-1]=r[top];
top--;
while(siz[st[top]]>((r[top]-l[top]+1)>>1)+1)
st[top]=Delete(st[top]);
}
}
ll ans=0;
for(int i=1;i<=top;i++)
{
for(int j=l[i];j<=r[i];j++)
ans+=abs(a[j]-a[st[i]]);
}
printf("%lld",ans);
return 0;
}