首先要知道每次拿走最小才会达到最优,因为最小的不会给其他的提供任何加分,只有可能减小加分。
删除卡片的次序确定了,剩下的就是确定每段区间的左右端点。
pos[i] 表示数字 i 在初始序列中的位置。
首先枚举i (i = 1 -> n),如果不需删除,则将pos[i]放入set<int> S中,如果不需删除,则在S中二分查找上下界。
总的时间复杂度为o( (n-k)*log(k) )。
#include <algorithm>
#include <iostream>
#include <cstring>
#include <cstdlib>
#include <cstdio>
#include <queue>
#include <cmath>
#include <stack>
#include <map>
#include <set>
#include <ctime>
#include <iomanip>
#pragma comment(linker,"/STACK:1024000000");
#define EPS (1e-6)
#define LL long long
#define ULL unsigned long long
#define INF 0x3f3f3f3f
#define Mod 1000000007
#define mod 1000000007
/** I/O Accelerator Interface .. **/
#define g (c=getchar())
#define d isdigit(g)
#define p x=x*10+c-'0'
#define n x=x*10+'0'-c
#define pp l/=10,p
#define nn l/=10,n
template<class T> inline T& RD(T &x)
{
char c;
while(!d);
x=c-'0';
while(d)p;
return x;
}
template<class T> inline T& RDD(T &x)
{
char c;
while(g,c!='-'&&!isdigit(c));
if (c=='-')
{
x='0'-g;
while(d)n;
}
else
{
x=c-'0';
while(d)p;
}
return x;
}
inline double& RF(double &x) //scanf("%lf", &x);
{
char c;
while(g,c!='-'&&c!='.'&&!isdigit(c));
if(c=='-')if(g=='.')
{
x=0;
double l=1;
while(d)nn;
x*=l;
}
else
{
x='0'-c;
while(d)n;
if(c=='.')
{
double l=1;
while(d)nn;
x*=l;
}
}
else if(c=='.')
{
x=0;
double l=1;
while(d)pp;
x*=l;
}
else
{
x=c-'0';
while(d)p;
if(c=='.')
{
double l=1;
while(d)pp;
x*=l;
}
}
return x;
}
#undef nn
#undef pp
#undef n
#undef p
#undef d
#undef g
using namespace std;
int num[1000010];
int pos[1000010];
bool ap[1000010];
int st[4001000];
set<int> s;
int Init(int site,int l,int r)
{
if(l == r)
return st[site] = 1;
int mid = (l+r)>>1;
return st[site] = Init(site<<1,l,mid) + Init(site<<1|1,mid+1,r);
}
int Query(int site,int L,int R,int l,int r)
{
if(L == l && R == r)
return st[site];
int mid = (L+R)>>1;
if(r <= mid)
return Query(site<<1,L,mid,l,r);
if(mid < l)
return Query(site<<1|1,mid+1,R,l,r);
return Query(site<<1,L,mid,l,mid) + Query(site<<1|1,mid+1,R,mid+1,r);
}
void Update(int site,int l,int r,int x)
{
if(l == r)
{
st[site] = 0;
return ;
}
int mid = (l+r)>>1;
if(x <= mid)
Update(site<<1,l,mid,x);
else
Update(site<<1|1,mid+1,r,x);
st[site] = st[site<<1] + st[site<<1|1];
}
int main()
{
int n,k,i,j,x;
scanf("%d %d",&n,&k);
for(i = 1;i <= n; ++i)
scanf("%d",&num[i]),pos[num[i]] = i;
memset(ap,false,sizeof(ap));
for(i = 1;i <= k; ++i)
scanf("%d",&x),ap[x] = true;
set<int>::iterator it;
LL sum = 0;
Init(1,1,n);
s.insert(n+1);
s.insert(0);
for(i = 1;i <= n; ++i)
{
if(ap[i])
{
s.insert(pos[i]);
continue;
}
it = s.upper_bound(pos[i]);
int r = *it-1;
int l = *(--it)+1;
sum += Query(1,1,n,l,r);
Update(1,1,n,pos[i]);
}
cout<<sum<<endl;
return 0;
}