乐滋滋在wc上讲的题
对于这类最优化问题,有一个套路是考虑答案的下界,然后判看能不能到达这个下界
首先一个显然的下界是
∑|i−ai|
∑
|
i
−
a
i
|
,但注意到这个下界不够紧,比如序列 1 0 3 2,0和3之间的间隔至少要跨越2次
建一个图,每本书代表一个点,连出一条有向边到他要去的位置,这个的意思就是指这个图中边不相交的环之间的跨越可能要额外的花费
这样确定的下界对于出发点在1的情况是对的,为什么?手玩一下可以发现,如果环A包含/相交环B(A在左B在右),环A可以在走到环B左端的时候放下书拿起B的书,走完环B回到左端点再拿起A的书继续走,如图
这是一个环AB和环CD相交的情况,我们可以从A走到C,将手里的书和C的书交换,走到D换书,走回C换书,接着走AB的环,这样的花费是贴着下界的
这时一个环AB包含环CD的情况,同样,我们从A走到C换书,走完环CD回到C换书,继续环AB,同样也是贴着下界的
那么为什么出发点不在1的时候不一定对呢?因为1一定是在一个最外围的环,比如上图,如果出发点在环CD上,而环AB包含环CD,那么走的时候不能解决AB,需要支付额外的花费从出发点走到环AB上才能接着按下界走,易证这也是最优的方案,于是问题就变成了出发点走到任意一个最外围的环上的最短路,点i向i-1,i+1连边权为1的边,向同一个环内的相邻点连边权为0的边跑最短路,因为边权只有0/1,最短路可以用bfs,0边放队头1边放队尾O(n)跑
(代码有个地方莓findfa挂掉了,现在fix了…)
我的做法是用并查集合并所有相交的环,然后相邻的环连边权1,跑bfs最短路
code:
#include<set>
#include<map>
#include<deque>
#include<queue>
#include<stack>
#include<cmath>
#include<ctime>
#include<bitset>
#include<string>
#include<vector>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<climits>
#include<complex>
#include<iostream>
#include<algorithm>
#define ll long long
using namespace std;
inline void read(int &x)
{
char c; while(!((c=getchar())>='0'&&c<='9'));
x=c-'0';
while((c=getchar())>='0'&&c<='9') (x*=10)+=c-'0';
}
inline void up(int &a,const int &b){if(a<b)a=b;}
inline void down(int &a,const int &b){if(a>b)a=b;}
const int maxn = 2100000;
int n,c[maxn],S;
int okl[maxn],okr[maxn];
ll ans;
int id[maxn],cnt,L[maxn],R[maxn],dep[maxn];
void mark(int x)
{
while(!id[x])
{
id[x]=cnt;
down(L[cnt],x); up(R[cnt],x);
x=c[x];
}
}
int fa[maxn];
int findfa(const int x){return fa[x]==x?x:fa[x]=findfa(fa[x]);}
int t[maxn],tp;
struct edge{int y,nex;}a[maxn<<1]; int len,fir[maxn];
inline void ins(const int x,const int y){a[++len]=(edge){y,fir[x]};fir[x]=len;}
int dis[maxn];
queue<int>q;
int main()
{
//freopen("tmp.in","r",stdin);
//
read(n); read(S); ++S;
for(int i=1;i<=n;i++)
{
read(c[i]); c[i]++;
ans+=(ll)abs(i-c[i]);
}
for(int i=1;i<=n;i++) okl[i]=okl[i-1]|(c[i]!=i);
for(int i=n;i>=1;i--) okr[i]=okr[i+1]|(c[i]!=i);
for(int i=1;i<=n;i++) if(!id[i])
{
++cnt,fa[cnt]=cnt; L[cnt]=R[cnt]=i;
mark(i);
}
tp=0;
for(int i=1;i<=n;i++)
{
int ii=findfa(id[i]);
if(i==L[ii]) t[++tp]=i;
while(t[tp]>L[ii])
{
int la=t[tp--];
int f1=findfa(ii),f2=findfa(id[la]);
fa[f2]=f1; up(R[f1],R[f2]),down(L[f1],L[f2]);
}
if(i==R[ii]) tp--;
}
for(int i=1;i<=cnt;i++) findfa(i),dis[i]=n+1;
for(int i=1,mxr=0;i<=n;i++)
{
int ii=fa[id[i]];
if(mxr<i)
{
dep[ii]=1;
if(i<=S) ans+=okl[i-1]?2:0;
else ans+=okr[i]?2:0;
}
up(mxr,R[ii]);
}
for(int i=1;i<=n;i++)
{
int ii=fa[id[i]];
if(i>1)
{
int j=fa[id[i-1]];
if(ii!=j) ins(ii,j);
}
if(i<n)
{
int j=fa[id[i+1]];
if(ii!=j) ins(ii,j);
}
}
S=fa[id[S]]; dis[S]=0,q.push(S);
while(!q.empty())
{
const int x=q.front(); q.pop();
for(int k=fir[x],y=a[k].y;k;k=a[k].nex,y=a[k].y) if(dis[y]>dis[x]+1)
dis[y]=dis[x]+1,q.push(y);
}
int mn=n+1;
for(int i=1;i<=cnt;i++) if(dep[i]) down(mn,dis[findfa(i)]);
ans+=mn<<1;
printf("%lld\n",ans);
return 0;
}