题目链接
简单题解
题意很有意思。给一个长度为n的数列,每个数在1-n之间且各不相同。你可以从这个数列中删数:每次选一段区间,可以删除这个区间中最小的那个数,然后每次删除得到的分数是这个区间的长度。题目要你把原序列删成一个规定的长度为k的序列,并要得分最高。
能确认的事情有:要删哪几个数是固定的;删一个数时候选的区间肯定要尽量大,尽量大就是分别找到这个数往左往右第一个比他小的,这中间夹的就是删这个数的时候要选的区间了(因为只能删一个区间里最小的数)。
想一下能发现从小数开始删是最优的,因为删了大数可能会让删小数时候的可选区间变小,而删小数不会减小大数的区间。那么问题就简单了:从小往大找要删的数,找到一个删一个,每次删的区间长度就是往右往左第一个比他小的数中间夹的还没被删的数的个数(这个用BIT求个和就行了)。至于怎么找区间的左右端点,从小到大把数的位置放到一个set里面二分即可。
代码
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<string>
#include<algorithm>
#include<set>
#include<vector>
#include<map>
using namespace std;
#define CLR(a,b) memset(a,b,sizeof(a))
typedef long long ll;
typedef pair<int,int> PII;
#define push_back pb
#define fi first
#define se second
int n,k;
const int maxn = 1000000+20;
int s[maxn];
int p[maxn],b[maxn],at[maxn];
bool mark[maxn];
void add (int idx , int v) {
for (int i = idx ; i <= n ; i += i & -i)
s[i] += v;
}
int sum (int idx) {
int ret = 0;
for (int i = idx ; i > 0 ; i -= i & -i)
ret += s[i];
return ret;
}
set<int> ms;
set<int>::iterator ite;
void solve()
{
ms.insert(0);
ms.insert(n+1);
ll ans = 0;
for(int i=1;i<=n;i++){
if(mark[i]){
ms.insert(at[i]);
continue;
}
ite = ms.upper_bound(at[i]);
int r = *ite - 1;
int l = *(--ite);
ans += sum(r) - sum(l);
add(at[i],-1);
}
printf("%I64d\n",ans);
}
int main()
{
scanf("%d%d",&n,&k);
for(int i=1;i<=n;i++){
scanf("%d",&p[i]);
add(i,1);
at[p[i]]=i;
}
for(int i=1;i<=k;i++){
scanf("%d",&b[i]);
mark[b[i]] = 1;
}
solve();
return 0;
}