题意
你有一个长度为 n n n 个队列,你每次取出这个队列的队首,如果这个数是队列中最小的数那么就把他丢掉,否则就放在队列的队尾。
解题方法
我们首先就想到了暴力的做法,每次枚举第一个然后扔到后面去,复杂度是 o ( n 2 ) o(n^2) o(n2)的,发现根本过不了,于是我们就要想到一些优化,我们发现这个队列就和一个环一样,所以我们先把原来的 a a a数组复制一遍,因为我们每次取出下一个被扔走的数一定是在 [ l a s t + 1 , l a s t + n − 1 ] [last+1,last+n-1] [last+1,last+n−1]这个区间内的,我们就就可以直接用线段树求出这个最小值得位置就可以了,取出这个数以后就把这个位置和位置 + n +n +n的值都变成一个很大数,然后我们该怎么样统计答案呢?答案不能是简单的现在的位置减去上去的位置,因为当中有些数已经被扔出去了,不能再重复计数了,我们就再开一个数据结构,查询 [ l a s t + 1 , n o w − 1 ] [last+1,now-1] [last+1,now−1]中有多少数已经出去了,对于每次扔出去的数我们在那个位置和那个位置 + n +n +n的地方都变成1,就可以解决了。
代码
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
int a[200500];
int n;
struct node
{
int l,r;
int minn,pos;
friend node operator + (node x,node y)
{
node lans;
if(x.minn<=y.minn)
{
lans.minn=x.minn;
lans.pos=x.pos;
}
else
{
lans.minn=y.minn;
lans.pos=y.pos;
}
return lans;
}
}t[200500<<2];
void pushup(int rt)
{
if(t[rt<<1].minn<=t[rt<<1|1].minn)
{
t[rt].minn=t[rt<<1].minn;
t[rt].pos=t[rt<<1].pos;
}
else
{
t[rt].minn=t[rt<<1|1].minn;
t[rt].pos=t[rt<<1|1].pos;
}
}
void build(int rt,int l,int r)
{
t[rt].l=l;
t[rt].r=r;
if(l==r)
{
t[rt].minn=a[l];
t[rt].pos=l;
return;
}
int mid=l+r>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
pushup(rt);
return;
}
void change(int rt,int x,int d)
{
if(t[rt].l==t[rt].r&&t[rt].l==x)
{
t[rt].minn=d;
return;
}
int mid=t[rt].l+t[rt].r>>1;
if(x<=mid)change(rt<<1,x,d);
else change(rt<<1|1,x,d);
pushup(rt);
}
node query(int rt,int l,int r)
{
if(t[rt].l>=l&&t[rt].r<=r)
{
node lans;
lans.minn=t[rt].minn;
lans.pos=t[rt].pos;
return lans;
}
int mid=t[rt].l+t[rt].r>>1;
node lans;
if(l<=mid)
{
lans=query(rt<<1,l,r);
if(mid<r)
lans=lans+query(rt<<1|1,l,r);
}
else
{
if(mid<r)
lans=query(rt<<1|1,l,r);
}
return lans;
}
int c[1000000];
int lowbit(int x)
{
return x&(-x);
}
int update(int x,int v)
{
while(x<=2*n)
{
c[x]+=v;
x+=lowbit(x);
}
}
int sum(int x)
{
int res=0;
while(x>0)
{
res+=c[x];
x-=lowbit(x);
}
return res;
}
int main()
{
scanf("%d",&n);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]),a[n+i]=a[i];
build(1,1,2*n);
int last=0;
long long ans=0;
for(int i=1;i<=n;i++)
{
node lans;
if(i==1)
{
lans=query(1,1,n);
}
else
lans=query(1,last+1,n+last-1);
int x=lans.pos;
ans+=x-last-(sum(x)-sum(last));
if(x>n)last=x-n;else last=x;
change(1,last,1e7);change(1,last+n,1e7);
update(last,1);update(last+n,1);
}
cout<<ans<<'\n';
return 0;
}