给出一个长度为N的序列,每一次我们操作可以选择一段长度为Len的区间,然后将其中最小的那个数删除掉,并且获得Len的价值,我们希望最终数组变成长度为M的那个序列,问我们过程中进行操作能够获得的最多的价值为多少。
保证每个数都只会出现1次。
思路:
①我们肯定是从最小的数开始删起是最优的,因为每一次我们选择一个最小的数进行删除的时候,能够拓展出来的这个长度为Len的区间长度会最大,所以我们考虑先从最小的需要被删除的数字开始删除。
②那么每一次我们取出当前最小需要删除的数,因为每个数只会出现一次,那么定位这个数原来在树中的位子是O(1)可以很容易做到的,将这个位子找到之后,我们向左拓展到最远处,使得最远处到当前这个数的位子的这个区间:【PosL,now】中,没有数字小于这个数。然后再向右拓展到最远处,使得最远处到当前这个数的位子的这个区间:【now,PosR】中,没有数字小于这个数。
因为我们删除的过程是从小到大的,所以我们这里无需考虑原数组中需要被删除的数字的情况,这些数即使在区间【PosL,now】中并且小于a【now】,我们也无需考虑。因为在此操作之前,我们已经将这个小于a【now】的数字已经删除掉了。
那么这里我们可以处理出来另外一个数组b【】;使得需要被删除的位子上的数b【i】=INF,不需要被删除的位子上的数b【i】=a【i】,那么我们二分的过程在b数组上进行即可。
我们要确定一个区间中是否包含一个数小于a【now】,我们还需要预处理一个区间RMQ,预处理一个ST表即可。
③然后我们每一次找到区间并且统计了值之后,需要一颗树状数组来维护哪些位子上还有数,哪些位子上的数已经删除了,用于统计答案过程所用。
问题稍微有些复杂,口述有很多不恰当的描述,具体参考代码理解即可,每一部分都是独立的,很好理解。
注意卡常问题= =
Ac代码(1990+ms过的):
#include<stdio.h>
#include<string.h>
#include<math.h>
#include<algorithm>
using namespace std;
struct node
{
int val,pos;
}del[1450000];
int poww[54];
int Len[1450000];
int vis[1450000];
int a[1450000];
int b[1450000];
int n,m;
int cmp(node a,node b)
{
return a.val<b.val;
}
/**************************************/
int tree[1540005];
int lowbit(int x)
{
return x&(-x);
}
int sum(int x)
{
int sum=0;
while(x>0)
{
sum+=tree[x];
x-=lowbit(x);
}
return sum;
}
void add(int x,int c)
{
while(x<=n)
{
tree[x]+=c;
x+=lowbit(x);
}
}
/**************************************/
int minn[1200005][40];
void ST()
{
int len=floor(log10(double(n))/log10(double(2)));
for(int j=1;j<=len;j++)
{
for(int i=1;i<=n+1-(1<<j);i++)
{
minn[i][j]=min(minn[i][j-1],minn[i+(1<<(j-1))][j-1]);
}
}
}
int getminn(int a,int b)
{
int len=Len[b-a+1];
return min(minn[a][len], minn[b-poww[len]+1][len]);
}
int main()
{
poww[0]=1;
for(int i=1;i<=30;i++)
{
poww[i]=poww[i-1]*2;
}
for(int i=0;i<=1000000;i++)
{
Len[i]=floor(log10(double(i))/log10(double(2)));
}
while(~scanf("%d%d",&n,&m))
{
memset(vis,0,sizeof(vis));
memset(tree,0,sizeof(tree));
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
for(int i=1;i<=m;i++)
{
int x;scanf("%d",&x);
vis[x]++;
}
int cnt=0;
for(int i=1;i<=n;i++)
{
if(vis[a[i]]==0)
{
++cnt;
del[cnt].val=a[i];
del[cnt].pos=i;
}
}
for(int i=1;i<=n;i++)
{
if(vis[a[i]]==1)b[i]=a[i];
else b[i]=0x3f3f3f3f;
minn[i][0]=b[i];
}
ST();
__int64 output=0;
sort(del+1,del+1+cnt,cmp);
for(int i=1;i<=n;i++)add(i,1);
int tot=0;
for(int i=1;i<=cnt;i++)
{
int j=del[i].pos;
int valj=del[i].val;
int PosL=-1;
int l=1;
int r=j;
while(r-l>=0)
{
int mid=(l+r)/2;
if(getminn(mid,j)>valj)
{
PosL=mid;
r=mid-1;
}
else l=mid+1;
}
int PosR=-1;
l=j;
r=n;
while(r-l>=0)
{
int mid=(l+r)/2;
if(getminn(j,mid)>valj)
{
PosR=mid;
l=mid+1;
}
else r=mid-1;
}
output+=sum(PosR)-sum(PosL-1);
add(j,-1);
}
printf("%I64d\n",output);
}
}