题意就不说了吧。
我们把题目模型找出来:
把ai在b数组出现的下标记作i号节点的目标节点。我们把目标节点连起来就构成了若干个环。每次我们有两种操作
交换两个元素来两个环合并。
交换一个元素直接将一个环消去(当然得花费代价)。
当然每次交换元素总有代价,要求最后把所有环消去的代价最小代价最小。
毋庸置疑我们一定是完成操作1之后再对所有环进行操作2。
第二个操作的代价很好处理,直接选出其中w最下的节点以此为断点将环消去,yy一下这一定是消去一个环的最优方法,此时的代价可计算为cir_sum+cir_min*(cir_len-2)。
下面我们来讨论一下操作1。假设我们有两个环i,j(假设cir_min_i<cir_min_j),那么两个环合起来消去的代价比分开消去的代价小,一定满足:
cir_sum_i+cir_min_i*(cir_len_i-2)+cir_sum_j+cir_min_j*(cir_len_j-2)>cir_sum_i+cir_sum_j+cir_min_i*(cir_len_i+cir_len_j-2) + cir_min_i+ cir_min_j
//cir_sum表示换上所有点的w和。cir_min表示环上点的最小w。cir_len表示换上点的个数
我们把题目模型找出来:
把ai在b数组出现的下标记作i号节点的目标节点。我们把目标节点连起来就构成了若干个环。每次我们有两种操作
交换两个元素来两个环合并。
交换一个元素直接将一个环消去(当然得花费代价)。
当然每次交换元素总有代价,要求最后把所有环消去的代价最小代价最小。
毋庸置疑我们一定是完成操作1之后再对所有环进行操作2。
第二个操作的代价很好处理,直接选出其中w最下的节点以此为断点将环消去,yy一下这一定是消去一个环的最优方法,此时的代价可计算为cir_sum+cir_min*(cir_len-2)。
下面我们来讨论一下操作1。假设我们有两个环i,j(假设cir_min_i<cir_min_j),那么两个环合起来消去的代价比分开消去的代价小,一定满足:
cir_sum_i+cir_min_i*(cir_len_i-2)+cir_sum_j+cir_min_j*(cir_len_j-2)>cir_sum_i+cir_sum_j+cir_min_i*(cir_len_i+cir_len_j-2) + cir_min_i+ cir_min_j
//cir_sum表示换上所有点的w和。cir_min表示环上点的最小w。cir_len表示换上点的个数
化简得cir_min_j*(cir_len_j-3)>cir_min_i*cir_len_j+ cir_min_i
很明显我们取的环i一定是cir_min最小的。这样我们找到合并好所有的环之后再消去环即可。
附上代码
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
typedef long long LL;
const int Maxn=1000005;
LL ans,cnt;
int w[Maxn],a[Maxn],lk[Maxn],fa[Maxn];
int tot,minx,i,n,x,t,s;
bool v[Maxn],flag[Maxn];
struct arr
{
int m,len;
bool operator <(const arr &a)const
{ return m<a.m; }
} Cir[Maxn];
int read(){
char ch=getchar();
int ret=0;
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9')
{ret=ret*10+ch-'0';ch=getchar();}
return ret;
}
int main(){
n=read();
for (i=1;i<=n;i++) {w[i]=read();ans+=(LL)w[i];}
for (i=1;i<=n;i++){
a[i]=read();
lk[ a[i] ]=i;
}
for (i=1;i<=n;i++){
x=read();
fa[ lk[x] ]=i;
}
for (i=1;i<=n;i++)
if (!v[i]){
for (t=i,s=cnt=0,minx=Maxn;!v[t];t=fa[t],s++){
if (w[a[t]]<minx) minx=w[a[t]];
v[t]=1;
}
Cir[tot].m=minx;
Cir[tot++].len=s;
}
sort(Cir,Cir+tot);
for (i=1;i<tot;i++)
if (Cir[i].m*(Cir[i].len-3)>Cir[0].m+Cir[0].m*Cir[i].len){
flag[i]=1; Cir[0].len+=Cir[i].len; ans+=(LL)Cir[0].m+Cir[i].m;
}
for (i=0;i<tot;i++)
if (!flag[i]) ans+=(LL)Cir[i].m*(Cir[i].len-2);
printf("%I64d\n",ans);
return 0;
}