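/*
 * Problem: given a sequence of length n, pick a contiguous window of length k
 * and change all k elements in it to one common value x. Changing an element
 * a_i to x costs |a_i - x|; find the minimum total cost over all windows.
 *
 * For a fixed window, the cost sum |a_i - x| is minimised when x is a median
 * of the k numbers: moving x away from the median moves it farther from at
 * least as many elements as it moves it closer to. This program uses the lower
 * median, i.e. the ((k+1)/2)-th smallest value.
 *
 * So we slide the window from left to right and keep its k elements in a
 * Treap. The Treap yields the median and the sum of the smaller half; the
 * window sum comes from prefix sums.
 * Time complexity: O(n log n). (Author's note: this easy problem still cost
 * one Wrong Answer...)
 *
 * Original notes:
 * 题目大意:给定一个长度为n的序列,求一个长度为k的子区间,将这个长度为k的区间变成一样的,代价总和最小,求最小花销
 * 显然选取的是这k个数的中位数时代价总和最小
 * 于是我们从左往右扫一遍 用一个Treap来维护这个长度为k的区间即可
 * 时间复杂度O(nlogn) 这水题居然还贡献了一个WA真是。。。
 */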
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
// Capacity of the 1-indexed arrays a[]/sum[] and of the Treap node pool;
// n must stay below M.
#define M 100100
// Subtree element count of a possibly-null Treap node (0 for null).
#define SIZE(p) ((p)?(p)->size:0)
// Subtree element sum of a possibly-null Treap node (0 for null).
#define SUM(p) ((p)?(p)->sum:0)
using namespace std;
// Randomised treap storing a multiset of ints.
// Equal values share one node, with their multiplicity kept in cnt.
// Every node caches the element count (size) and element sum (sum) of its
// subtree, duplicates included, so rank and prefix-sum queries cost O(depth).
// Heap order: a parent's random key is not smaller than its children's keys.
struct Treap{
    Treap *ls,*rs;   // left / right child, 0x0 when absent
    int val,key;     // stored value; random heap priority
    int cnt,size;    // copies of val in this node; elements in this subtree
    long long sum;   // sum of all elements in this subtree

    // Allocator behind `new (v) Treap`: hands out the next slot of a static
    // pool of M nodes, initialised as a leaf holding one copy of `_`.
    // Slots are never recycled, so at most M nodes can ever be created.
    void* operator new (size_t,int _)
    {
        static Treap mempool[M],*C=mempool;
        C->ls=C->rs=0x0;
        C->sum=C->val=_;
        C->key=rand();
        C->cnt=C->size=1;
        return C++;
    }

    // Recompute size and sum from cnt/val and the (already refreshed) children.
    void Push_Up()
    {
        sum=(long long)val*cnt;
        size=cnt;
        if(ls) sum+=ls->sum,size+=ls->size;
        if(rs) sum+=rs->sum,size+=rs->size;
    }

    // Right rotation: the left child becomes the root of this subtree.
    // The demoted node is refreshed first, because the new root's
    // aggregates include it.
    friend void Zig(Treap *&x)
    {
        Treap *y=x->ls;
        x->ls=y->rs;
        y->rs=x;
        x->Push_Up();
        y->Push_Up();
        x=y;
    }

    // Left rotation: the right child becomes the root of this subtree.
    // As in Zig, the demoted node is refreshed before the new root.
    friend void Zag(Treap *&x)
    {
        Treap *y=x->rs;
        x->rs=y->ls;
        y->ls=x;
        x->Push_Up();
        y->Push_Up();
        x=y;
    }

    // Add one copy of y under x. A new value gets a fresh leaf, which is
    // rotated above its parent whenever its key is larger (restoring heap order).
    friend void Insert(Treap *&x,int y)
    {
        if(!x)
        {
            x=new (y)Treap;
            return ;
        }
        if(y==x->val)
            x->cnt++;
        else if(y<x->val)
        {
            Insert(x->ls,y);
            if(x->ls->key>x->key)
                Zig(x);
        }
        else
        {
            Insert(x->rs,y);
            if(x->rs->key>x->key)
                Zag(x);
        }
        // Needed for the cnt++ path; harmless after a rotation.
        x->Push_Up();
    }

    // Remove one copy of y from under x; does nothing if y is absent.
    // A node losing its last copy is rotated downward until it has at most
    // one child and can be spliced out. Each rotation lifts the child with
    // the larger key, which keeps heap order intact.
    friend void Delete(Treap *&x,int y)
    {
        if(!x) return;
        if(y<x->val)
            Delete(x->ls,y);
        else if(y>x->val)
            Delete(x->rs,y);
        else if(x->cnt>=2)
            --x->cnt;
        else if(!x->ls)
            x=x->rs;
        else if(!x->rs)
            x=x->ls;
        else if(x->ls->key>x->rs->key)
        {
            Zig(x);             // target node is now the right child
            Delete(x->rs,y);
        }
        else
        {
            Zag(x);             // target node is now the left child
            Delete(x->ls,y);
        }
        if(x) x->Push_Up();
    }

    // Value of the y-th smallest element (1-based, duplicates counted).
    // Requires 1 <= y <= SIZE(x); otherwise it walks into a null child.
    friend int Get_Kth(Treap *x,int y)
    {
        if(y<=SIZE(x->ls))
            return Get_Kth(x->ls,y);
        if(y<=SIZE(x->ls)+x->cnt)
            return x->val;
        else
            return Get_Kth(x->rs,y-SIZE(x->ls)-x->cnt);
    }

    // Sum of the y smallest elements (duplicates counted individually).
    // Returns 0 for an empty subtree; y above the subtree size gives the full sum.
    friend long long Query(Treap *x,int y)
    {
        if(!x) return 0;
        if(y<=SIZE(x->ls))
            return Query(x->ls,y);
        if(y<=SIZE(x->ls)+x->cnt)
            return SUM(x->ls) + (long long)x->val*(y-SIZE(x->ls));
        else
            return SUM(x->ls) + (long long)x->val*x->cnt + Query(x->rs,y-SIZE(x->ls)-x->cnt);
    }
}*root;
// n, k   : sequence length and window length
// a[i]   : the sequence, 1-indexed
// sum[i] : prefix sums, sum[i] = a[1] + ... + a[i], sum[0] = 0
// ans    : best cost seen so far, starting from a large sentinel
int n,k;
int a[M];
long long sum[M];
long long ans=0x3f3f3f3f3f3f3f3fll;

// Reads n, k and the sequence. Slides a length-k window over it while the
// Treap holds exactly the window's elements. Prints the minimum cost of
// levelling one window to its lower median.
int main()
{
    srand(19980402);
    cin>>n>>k;
    for(int idx=1;idx<=n;++idx)
    {
        scanf("%d",&a[idx]);
        sum[idx]=a[idx]+sum[idx-1];
    }

    const int half=(k+1)>>1;   // rank of the lower median inside a window
    for(int pos=1;pos<=n;++pos)
    {
        Insert(root,a[pos]);
        if(pos<k)
            continue;          // window not full yet

        // The Treap now holds a[pos-k+1 .. pos].
        const long long median=Get_Kth(root,half);
        const long long lowSum=Query(root,half);             // the `half` smallest
        const long long highSum=sum[pos]-sum[pos-k]-lowSum;  // the remaining k-half
        const long long raiseCost=median*half-lowSum;        // lift small ones to the median
        const long long lowerCost=highSum-median*(k-half);   // bring large ones down to it
        ans=min(ans,raiseCost+lowerCost);

        Delete(root,a[pos-k+1]);   // drop the leftmost element before sliding
    }
    cout<<ans<<endl;
    return 0;
}