对于长度为
k
k
的区间,取的一定是他们的中位数(第大的数)
现在要不断挪动这个区间(添加,删除),那么对于这三个操作,一棵平衡树就可以实现啦!
代码如下:
#include<algorithm>
#include<ctype.h>
#include<cstdio>
#define N 100050
using namespace std;
inline int read(){
int x=0,f=1;char c;
do c=getchar(),f=c=='-'?-1:f; while(!isdigit(c));
do x=(x<<3)+(x<<1)+c-'0',c=getchar(); while(isdigit(c));
return x*f;
}
struct Node{
Node *ch[2],*fa;
int x,siz,cnt;
long long sum;
inline void maintain(){
siz=cnt+ch[0]->siz+ch[1]->siz;
sum=cnt*x+ch[0]->sum+ch[1]->sum;
return;
}
inline int cmp(int k){
if(k==x) return -1;
return k<x?0:1;
}
inline int dir(){
if(fa->ch[0]==this) return 0;
if(fa->ch[1]==this) return 1;
return -1;
}
Node(int);
}*null,*root,*tmp;
Node::Node(int _):x(_),sum(_){
siz=cnt=1;
ch[0]=ch[1]=fa=null;
}
inline void init(){
null=new Node(-1);
null->fa=null->ch[0]=null->ch[1]=null;
null->siz=null->cnt=null->sum=0;
root=null;
}
void Insert(int k,Node *&x,Node *fa){
if(x==null){
x=new Node(k);
tmp=x;
x->fa=fa;
return;
}
int d=x->cmp(k);
if(!~d){
x->cnt++,x->siz++;
x->sum+=k;
tmp=x;
}
else Insert(k,x->ch[d],x);
x->maintain();
return;
}
inline void Rotate(Node *x,int d){
Node *k=x->ch[d^1];
x->ch[d^1]=k->ch[d];
if(x->ch[d^1]!=null) x->ch[d^1]->fa=x;
k->ch[d]=x;
if(x->fa!=null) x->fa->ch[x->dir()]=k;
k->fa=x->fa;x->fa=k;
x->maintain();k->maintain();
return;
}
inline void Splay(Node *x,Node *y){
while(x->fa!=y){
if(x->fa->fa!=y && x->dir()==x->fa->dir())
Rotate(x->fa->fa,x->dir()^1);
Rotate(x->fa,x->dir()^1);
}
if(y==null) root=x;
return;
}
inline void AddNew(int x){
Insert(x,root,null);
Splay(tmp,null);
return;
}
Node *LowerPointer(int k,Node *x){
if(x==null) return null;
if(x->x>=k) return LowerPointer(k,x->ch[0]);
Node *tmp=LowerPointer(k,x->ch[1]);
return tmp==null?x:tmp;
}
Node *UpperPointer(int k,Node *x){
if(x==null) return null;
if(x->x<=k) return UpperPointer(k,x->ch[1]);
Node *tmp=UpperPointer(k,x->ch[0]);
return tmp==null?x:tmp;
}
inline void Delete(int k){
Node *a=LowerPointer(k,root),*b=UpperPointer(k,root);
if(a==null && b==null){
root->cnt--;root->siz--;
root->sum-=root->x;
return;
}
if(a==null){
Splay(b,null);
if(root->ch[0]->cnt>1){
root->ch[0]->cnt--;root->ch[0]->siz--;
root->ch[0]->sum-=root->ch[0]->x;
}
else root->ch[0]=null;
root->maintain();
return;
}
if(b==null){
Splay(a,null);
if(root->ch[1]->cnt>1){
root->ch[1]->cnt--;root->ch[1]->siz--;
root->ch[1]->sum-=root->ch[1]->x;
}
else root->ch[1]=null;
root->maintain();
return;
}
Splay(a,null);Splay(b,a);
if(root->ch[1]->ch[0]->cnt>1){
root->ch[1]->ch[0]->cnt--;root->ch[1]->ch[0]->siz--;
root->ch[1]->ch[0]->sum-=root->ch[1]->ch[0]->x;
}
else root->ch[1]->ch[0]=null;
root->ch[1]->maintain();
root->maintain();
return;
}
Node *K_th(int k,Node *x){
if(k>x->ch[0]->siz && k<=x->ch[0]->siz+x->cnt)
return x;
int d=k<=x->ch[0]->siz?0:1;
return K_th(k-(d?x->ch[0]->siz+x->cnt:0),x->ch[d]);
}
int n,k;
long long ans,t;
int a[N];
main(){
init();
n=read();k=read();
for(int i=1;i<=n;i++)
a[i]=read();
for(int i=1;i<=k;i++)
AddNew(a[i]);
tmp=K_th((k+1)>>1,root);
Splay(tmp,null);
ans=root->x*root->ch[0]->siz-root->ch[0]->sum+root->ch[1]->sum-root->x*root->ch[1]->siz;
for(int i=k+1;i<=n;i++){
Delete(a[i-k]);
AddNew(a[i]);
tmp=K_th((k+1)>>1,root);
Splay(tmp,null);
t=root->x*root->ch[0]->siz-root->ch[0]->sum+root->ch[1]->sum-root->x*root->ch[1]->siz;
ans=min(t,ans);
}
printf("%lld",ans);
return 0;
}