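// Approach: concatenate the two stacks into one array, keep a pointer at the boundary between them, and use a Fenwick tree (binary indexed tree) to count how many elements have to be popped.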
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<cmath>
#include<algorithm>
#include<iomanip>
#include<vector>
#include<map>
#include<set>
#include<bitset>
#include<queue>
#include<stack>
using namespace std;
// Capacity of the global arrays c[], p[], a[] and tls[].
#define MAXN 100010
// General-purpose constants; nothing in this file's visible code uses them.
#define MAXM 1010
#define INF 1000000000
#define MOD 1000000007
#define eps 1e-8
// Shorthand for the 64-bit integer type used for the answer.
#define ll long long
// Lowest set bit of x (0 when x is 0). Fully parenthesized so that the macro
// behaves as a single value inside larger expressions such as
// `lb(i) == 4` or `lb(a + b)`. Note that x is evaluated twice.
#define lb(x) ((x)&-(x))
// Sizes of the two input sequences, and their sum (n = n1 + n2).
// n is also the upper index bound used by change().
int n1;
int n2;
int n;

// Boundary index maintained by main(): positions <= wzh belong to one side,
// positions > wzh to the other.
int wzh;

// a[]   : the input values, used 1-based.
// tls[] : scratch array of the same capacity, for a sorted copy of a[].
int a[MAXN], tls[MAXN];

// p[r] : position in a[] of the value whose rank is r.
int p[MAXN];

// Fenwick tree storage, updated by change() and queried by ask().
int c[MAXN];

// Lookup from a value to its rank.
map<int, int> h;
// Fenwick tree point update: adds y to the entry at 1-based position x,
// touching every node of c[] that covers x (up to the global bound n).
//
// Positions outside 1..n are ignored. x > n already fell through the loop;
// x <= 0 is rejected explicitly because lb(0) == 0 would never advance the
// loop and a negative x would write before the start of c[].
void change(int x,int y){
    if(x<=0){
        return;
    }
    for(;x<=n;x+=lb(x)){
        c[x]+=y;
    }
}
int ask(int x){
int re=0;
for(;x;x-=lb(x)){
re+=c[x];
}
return re;
}
ll ans;
int main(){
int i,x;
scanf("%d%d",&n1,&n2);
n=n1+n2;
for(i=n1;i;i--){
scanf("%d",&a[i]);
}
for(i=1;i<=n2;i++){
scanf("%d",&a[n1+i]);
}
memcpy(tls,a,sizeof(a));
sort(tls+1,tls+n+1);
for(i=1;i<=n;i++){
h[tls[i]]=i;
}
for(i=1;i<=n;i++){
p[h[a[i]]]=i;
}
for(i=1;i<=n;i++){
change(i,1);
}
wzh=n1;
for(i=n;i;i--){
if(p[i]<=wzh){
ans+=ask(wzh)-ask(p[i]);
change(p[i],-1);
wzh=p[i];
}else{
ans+=ask(p[i]-1)-ask(wzh);
change(p[i],-1);
wzh=p[i]-1;
}
}
printf("%lld\n",ans);
return 0;
}
/*
*/