题目描述
题解
维护一段长度为k的区间的splay,每一次查询中位数即可。
每一个点只会被加入和删除一次,时间复杂度
O(2nlogn)
。
代码
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#define LL long long
#define N 100005
const LL inf=1e18;
int n,k,root,sz;
LL h[N];
int f[N],ch[N][2];
LL key[N],size[N],sum[N];
LL ans=inf;
void clear(int x)
{
f[x]=ch[x][0]=ch[x][1]=key[x]=size[x]=sum[x]=0;
}
int get(int x)
{
return ch[f[x]][1]==x;
}
void update(int x)
{
size[x]=1;sum[x]=key[x];
if (ch[x][0])
{
size[x]+=size[ch[x][0]];
sum[x]+=sum[ch[x][0]];
}
if (ch[x][1])
{
size[x]+=size[ch[x][1]];
sum[x]+=sum[ch[x][1]];
}
}
void rotate(int x)
{
int old=f[x],oldf=f[old],wh=get(x);
ch[old][wh]=ch[x][wh^1];
if (ch[old][wh]) f[ch[old][wh]]=old;
ch[x][wh^1]=old;
f[old]=x;
if (oldf) ch[oldf][ch[oldf][1]==old]=x;
f[x]=oldf;
update(old);
update(x);
}
void splay(int x)
{
for (int fa;fa=f[x];rotate(x))
if (f[fa])
rotate( (get(x)==get(fa))?fa:x );
root=x;
}
void insert(LL x)
{
if (!root)
{
root=++sz;
size[sz]=1;key[sz]=sum[sz]=x;
return;
}
int now=root,fa=0;
while (1)
{
fa=now;
now=ch[now][x>key[now]];
if (!now)
{
++sz;
f[sz]=fa;ch[fa][x>key[fa]]=sz;
size[sz]=1;key[sz]=sum[sz]=x;
update(fa);
splay(sz);
break;
}
}
}
int find(int x)
{
int now=root;
while (1)
{
if (x<=size[ch[now][0]]) now=ch[now][0];
else
{
x-=size[ch[now][0]];
if (x==1) return now;
--x;
now=ch[now][1];
}
}
}
int pre()
{
int now=ch[root][0];
while (ch[now][1]) now=ch[now][1];
return now;
}
void del(int x)
{
splay(x);
if (!ch[root][0]&&!ch[root][1])
{
clear(root);
root=0;
return;
}
if (!ch[root][0])
{
int oldroot=root;
root=ch[oldroot][1];
f[root]=0;
clear(oldroot);
return;
}
if (!ch[root][1])
{
int oldroot=root;
root=ch[oldroot][0];
f[root]=0;
clear(oldroot);
return;
}
int oldroot=root;
int leftbig=pre();
splay(leftbig);
ch[root][1]=ch[oldroot][1];
f[ch[root][1]]=root;
clear(oldroot);
update(root);
return;
}
int main()
{
scanf("%d%d",&n,&k);
for (int i=1;i<=n;++i) scanf("%lld",&h[i]);
for (int i=1;i<=k;++i) insert(h[i]);
int now=find((k>>1)+1);
splay(now);
LL cost=(size[ch[root][0]]*key[root]-sum[ch[now][0]]+sum[ch[root][1]]-size[ch[root][1]]*key[root]);
ans=min(ans,cost);
for (int i=k+1;i<=n;++i)
{
insert(h[i]);
del(i-k);
int now=find((k>>1)+1);
splay(now);
LL cost=(size[ch[root][0]]*key[root]-sum[ch[now][0]]+sum[ch[root][1]]-size[ch[root][1]]*key[root]);
ans=min(ans,cost);
}
printf("%lld\n",ans);
}