题意
给\(n(n \le 10^6)\)个数的序列\(a\),求一个递增序列\(b\)使得\(\sum_{i=1}^{n} |a_i-b_i|\)最小。
分析
神题啊不会。
具体证明看黄源河论文《左偏树的特点及其应用》
思路:
- 将问题转化为求一个不降序列\(b\)。
- 如果\(a_1 \le a_2 \le ... \le a_n\),则最优解显然是\(b_i=a_i\)
- 如果\(a_1 \ge a_2 \ge ... \ge a_n\),则最优解显然是\(b_i=w\),其中\(w\)是\(a\)序列的中位数\(a_{\left \lfloor \frac{n+1}{2} \right \rfloor}\)
如果有两段\(a_1, a_2, ..., a_n\)和\(a_{n+1}, a_{n+2}, ..., a_{m}\),其中左边的最优解是\(b_i=u(1 \le i \le n)\),右边的最优解是\(b_i=w(n < i \le m)\),则 - 当\(u \le w\),则最优解显然是\(b_i=u(1 \le i \le n), b_i=w(n < i \le m)\)
- 当\(u > w\)时,则最优解是\(b_i=x(1 \le i \le m)\),其中\(x\)是\(a\)序列的中位数\(a_{\left \lfloor \frac{m+1}{2} \right \rfloor}\),证明如下:
首先我们来证明,对于任意序列\(a\),如果最优解是\(b_i=w(1 \le i \le n)\),其中\(w\)是中位数,那么对于所有的\(w \le w ' \le c_1\)或\(c_n \le w ' \le w\),解\(b_i=c_i(1 \le i \le n)\)都不会比解\(b_i=w ' (1 \le i \le n)\)更优。
然后我并没有看懂那个归纳证明QAQ
就是这句:
因为如果解变坏了,由归纳假设可知a[2],a[3],...,a[n]的中位数w>u,这样的话,最优解就应该为(u, u, ... , u, w, w, ... ,w ),矛盾。
现在回到\(2\),显然最优解中\(b_n \le u, b_{n+1} \ge w\),然后根据刚刚我们证明的东西,则最优解中肯定\(b_i=b_n(1 \le i < n), b_i=b_{n+1}(n < i \le m)\)。
也就是说,给你\(m\)个点要求找两个值\(u \le w\),使得前\(n\)个点到\(u\)的距离和加上剩下的点到\(w\)的距离和最短。显然一组最优解是\(u=w=a_{\left \lfloor \frac{m+1}{2} \right \rfloor}\)
至于思路\(1\)中怎么转化问题,就很简单了:
\(\sum_{i=1}^{n} |a_i-b_i| = \sum_{i=1}^{n} |(a_i-i)-(b_i-i)|\)
则令新的\(a ' _ i = a_i - i\)就行了。
至于思路\(2\)中怎么讨论,一个长度为\(n\)的不降序列可以看做\(n\)个不升序列。
题解
所以我们从左到右合并,如果新加进来的数和前面的数不构成不升序列,则合并相邻的。
于是问题转化为如何维护中位数。
然后发现对于正整数\(n, m\)有\(\left \lfloor \frac{n+1}{2} \right \rfloor + \left \lfloor \frac{m+1}{2} \right \rfloor \ge \left \lfloor \frac{n+m+1}{2} \right \rfloor\),所以我们只需要维护两个区间的中位数及比中位数小的数即可。然后合并的时候再考虑删掉一些数即可。
所以删除最大的数、合并两个东西这个活交给左偏树。
#include <bits/stdc++.h>
using namespace std;
inline int getint() {
int x=0;
char c=getchar();
for(; c<'0'||c>'9'; c=getchar());
for(; c>='0'&&c<='9'; x=x*10+c-'0', c=getchar());
return x;
}
const int N=1000105;
int q[N], s[N], t[N], n;
struct node *null;
struct node {
node *c[2];
int d, w;
void init(int _w) {
c[0]=c[1]=null;
w=_w;
d=0;
}
void up() {
if(c[0]->d<c[1]->d) {
swap(c[0], c[1]);
}
d=c[1]->d+1;
}
}Po[N], *iT=Po, *root[N];
inline node *newnode(int w) {
iT->init(w);
return iT++;
}
inline node *merge(node *l, node *r) {
if(l==null || r==null) {
return l==null?r:l;
}
if(l->w<r->w) {
swap(l, r);
}
l->c[1]=merge(l->c[1], r);
l->up();
return l;
}
int main() {
null=newnode(-(~0u>>1));
null->c[0]=null->c[1]=null;
n=getint();
int top=0;
for(int i=1; i<=n; ++i, ++top) {
root[top+1]=newnode(t[i]=getint()-i);
q[top+1]=i;
s[top+1]=1;
for(; top && root[top]->w>root[top+1]->w; --top) {
s[top]+=s[top+1];
root[top]=merge(root[top], root[top+1]);
for(int mid=(i-q[top]+2)>>1; s[top]>mid; --s[top], root[top]=merge(root[top]->c[0], root[top]->c[1]));
}
}
long long ans=0;
q[top+1]=0;
top=0;
for(int i=1; i<=n; ++i) {
for(; q[top+1]==i; ++top);
ans+=abs(t[i]-root[top]->w);
}
printf("%lld\n", ans);
return 0;
}