思路:按数值从小到大依次删除。用 l[i]、r[i] 两个数组记录删除位置 i 上的数时,所选区间向左、向右最远能延伸到的边界;再用树状数组统计该区间内已经被删去的位置个数,区间长度减去这个个数就是本次删除对答案的贡献。
AC代码如下:
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#include<vector>
#define M 1000010
using namespace std;

// Capacity of every array below is M; all of them are indexed from 1 in main().

// l[i] / r[i]: leftmost / rightmost position of the window computed for position i.
int l[M];
int r[M];
// l2[i] / r2[i]: nearest position at or to the left / right of i whose value
// is marked in vis (0 on the left and n+1 on the right act as end markers).
int l2[M];
int r2[M];
// num1: the first input sequence (n values); num2: the second one (m values).
int num1[M];
int num2[M];
// pos[v]: index in num1 at which value v was read.
int pos[M];
// vis[v]: 1 when value v occurs in num2, 0 otherwise; main() stores 2 at the
// sentinel values 0 and n+1.
int vis[M];
// Fenwick (binary indexed) tree over positions, maintained by update()/sum().
int tree[M];
// Sizes of the two input sequences.
int n;
int m;

// Running total printed at the end of main().
long long ans = 0;
// Isolates the least significant set bit of x; yields 0 when x is 0.
int lowbit(int x)
{
    const int negated = -x;
    return x & negated;
}
// Fenwick-tree point update: adds 1 at position x by incrementing every
// tree node on the path x, x + lowbit(x), ... while the index stays <= n.
void update(int x)
{
    while (x <= n) {
        ++tree[x];
        x += lowbit(x);
    }
}
// Fenwick-tree prefix query: returns the total stored for positions 1..x.
// Any x <= 0 yields 0 because the loop body never runs.
int sum(int x)
{
    int total = 0;
    while (x > 0) {
        total += tree[x];
        x -= lowbit(x);
    }
    return total;
}
int main()
{ int i,j,k,len;
scanf("%d%d",&n,&m);
for(i=1;i<=n;i++)
{ scanf("%d",&num1[i]);
l[i]=i;r[i]=i;l2[i]=i;r2[i]=i;
pos[num1[i]]=i;
}
for(i=1;i<=m;i++)
{ scanf("%d",&num2[i]);
vis[num2[i]]=1;
}
vis[0]=2;vis[n+1]=2;
for(i=1;i<=n;i++)
if(vis[num1[i]]==0)
l2[i]=l2[i-1];
r2[1+n]=1+n;
for(i=n;i>=1;i--)
if(vis[num1[i]]==0)
r2[i]=r2[i+1];
for(i=1;i<=n;i++)
while(true)
{ if(num1[l[i]-1]>=num1[i])
l[i]=l[l[i]-1];
else if(vis[num1[l[i]-1]]==0)
l[i]=l2[l[i]-1]+1;
else
break;
}
for(i=n;i>=1;i--)
while(true)
{ if(num1[r[i]+1]>=num1[i])
r[i]=r[r[i]+1];
else if(vis[num1[r[i]+1]]==0)
r[i]=r2[r[i]+1]-1;
else
break;
}
for(i=1;i<=n;i++)
if(vis[i]==0)
{ k=pos[i];
ans+=r[k]-l[k]+1-sum(r[k])+sum(l[k]-1);
update(k);
}
printf("%I64d\n",ans);
}