做法
枚举每长度为 k k 的段寻找中位数即可。splay维护。
代码
=> 主要是想说这一点,由于计算的必要,相同的数不能合并到一个节点,否则之后调用 sum[ch[x][0]]/sum[ch[x][1]]
的时候会漏算和节点 相同的数。
#include<bits/stdc++.h>
#define rep(i,x,y) for (int i=(x); i<=(y); i++)
#define ll long long
#define ld long double
#define inf 1000000000
#define INF 1000000000000000000ll
using namespace std;
ll read(){
char ch=getchar(); ll x=0; int op=1;
for (; !isdigit(ch); ch=getchar()) if (ch=='-') op=-1;
for (; isdigit(ch); ch=getchar()) x=(x<<1)+(x<<3)+ch-'0';
return x*op;
}
#define N 100005
int n,m,a[N],data[N],fa[N],ch[N][2],siz[N],rt,tot; ll ans,sum[N];
void up(int x){
sum[x]=sum[ch[x][0]]+sum[ch[x][1]]+data[x];
siz[x]=siz[ch[x][0]]+siz[ch[x][1]]+1;
}
void rot(int x){
int y=fa[x],z=fa[y],f=ch[y][1]==x;
ch[y][f]=ch[x][f^1]; if (ch[x][f^1]) fa[ch[x][f^1]]=y;
fa[x]=z; if (z) ch[z][ch[z][1]==y]=x;
fa[y]=x; ch[x][f^1]=y; up(y); up(x);
}
void splay(int x,int tp){
while (fa[x]!=tp){
int y=fa[x],z=fa[y];
if (z!=tp) rot((ch[z][0]==y)==(ch[y][0]==x)?y:x);
rot(x);
}
if (!tp) rt=x;
}
void insert(int val){
int x=rt;
if (!rt){
rt=x=++tot;
ch[x][0]=ch[x][1]=fa[x]=0;
data[x]=sum[x]=val; siz[x]=1;
return;
}
while (x){
int &y=ch[x][val>data[x]];
if (!y){
y=++tot;
ch[y][0]=ch[y][1]=0; fa[y]=x;
data[y]=sum[y]=val; siz[y]=1;
x=y; break;
}
x=y;
}
splay(x,0);
}
int find(int val){
int x=rt;
while (ch[x][val>data[x]] && data[x]!=val) x=ch[x][val>data[x]];
splay(x,0); return x;
}
int getpre(int val){
int x=find(val); if (data[x]<val) return x;
x=ch[x][0];
while (ch[x][1]) x=ch[x][1];
return x;
}
int getnxt(int val){
int x=find(val); if (data[x]>val) return x;
x=ch[x][1];
while (ch[x][0]) x=ch[x][0];
return x;
}
void delet(int val){
int x=getpre(val),y=getnxt(val);
splay(x,0); splay(y,x);
int &z=ch[y][0];
data[z]=sum[z]=siz[z]=0,z=0,splay(y,0);
}
int getkth(int k){
int x=rt;
while (1){
if (k<=siz[ch[x][0]]) x=ch[x][0];
else if (k>siz[ch[x][0]]+1) k-=siz[ch[x][0]]+1,x=ch[x][1];
else return x;
}
}
int main(){
n=read(),m=read();
rep (i,1,n) a[i]=read();
insert(-inf); insert(inf);
rep (i,1,m-1) insert(a[i]);
ans=INF;
rep (i,m,n){
insert(a[i]);
int tmp=getkth((m+1)/2+1);//返回节点编号
splay(tmp,0);
ans=min(ans,(ll)data[tmp]*((m+1)/2-1)-(sum[ch[tmp][0]]+inf)+(sum[ch[tmp][1]]-inf)-(ll)data[tmp]*(m-(m+1)/2));
delet(a[i-m+1]);
}
cout<<ans<<endl;
return 0;
}