题目描述
Q市发生了一起特大盗窃案。这起盗窃案是由多名盗窃犯联合实施的,你要做的就是尽可能多的抓捕盗窃犯。
已知盗窃犯分布于
N
N
N个地点,以及第
i
i
i个地点初始有
a
i
a_i
ai名盗窃犯。
特别的是,对于每一个地点
u
u
u,都有一个固定的地点
v
v
v–当前如果某个盗窃犯位于地点
u
u
u,在下一个时刻他会移动到地点
v
v
v。
你需要通过初始时在某些点设置哨卡来捉住他们。
现在你可以在
M
M
M个地点设置哨卡,如果在某个地点设置哨卡,你可以抓获在任一时刻经过该地点的盗窃犯。
也就是说,哨卡存在的时间是无限长,但哨卡不能移动。
输入描述
第一行两个整数
N
N
N,
M
M
M(1
≤
\leq
≤
N
N
N,
M
M
M
≤
\leq
≤
1
0
5
10^5
105)。
第二行
N
N
N个整数,
a
1
,
a
2
.
.
.
a
n
a_1,a_2...a_n
a1,a2...an(0
≤
\leq
≤
a
1
,
a
2
,
.
.
.
a
n
a_1,a_2,...a_n
a1,a2,...an
≤
\leq
≤
1
0
5
10^5
105),表示第
i
i
i个地方初始有
a
i
a_i
ai名盗窃犯。
第三行
N
N
N个整数
v
1
,
v
2
,
v
3
.
.
.
v
n
v_1,v_2,v_3...v_n
v1,v2,v3...vn(1
≤
\leq
≤
v
1
,
v
2
,
v
3
,
.
.
.
v
n
v_1,v_2,v_3,...v_n
v1,v2,v3,...vn
≤
\leq
≤ N),表示当时处于地点i的盗窃犯下一个时刻会移动到地点
v
i
v_i
vi。
输出描述
输出一行一个整数–能够抓捕到的最大数量。
输入 |
---|
8 2 |
1 2 3 4 1 2 3 12 |
2 3 3 3 6 7 5 8 |
输出 |
22 |
说明 |
在地点3、地点8分别设置一个哨卡,此时答案为1+2+3+4+12=22 |
输入 |
---|
8 2 |
1 2 3 4 5 6 7 8 |
2 3 4 5 6 7 8 8 |
输出 |
36 |
说明 |
在地点2、地点8分别设置一个哨卡,此时答案为1+2+3+4+5+6+7+8=36 |
题目大意
第 i i i个地点初始有 a i a_i ai名盗窃犯,每个点 u i u_i ui的盗窃犯在下一秒钟会移动到 v i v_i vi,在 N N N个点中选择 M M M个点,使得无限时间之后经过这 M M M个点的盗窃犯最多。
解题思路
最开始的时候拿到这道题都没什么思路,但仔细一想发现这是一个分堆问题,每个点下跳的那个点可以看作是它的父亲结点(最好情况所有结点都在一起),并且可以解决有环的情况。那么,通过并查集分堆之后,在同一堆的所有结点经过无限时间之后,肯定会到达根结点(即我们应该设置哨卡的位置)。然后,统计每一个分堆的盗窃犯数量之后进行排序,前 m m m个数量之和即为答案。
AC代码
#include <bits/stdc++.h>
#define INF 0x3f3f3f
using namespace std;
const int mod=1e9+7;
const int Max_N=1e5+4;
typedef pair<int,int>P;
typedef long long ll;
typedef unsigned long long ull;
int pre[Max_N];
bool cmp(int a,int b)
{
return a>b;
}
void init()
{
for(int i=1;i<=Max_N;i++)
pre[i]=i;
}
int find(int x)
{
int r=x;
while(r!=pre[r])
{
r=pre[r];
}
int i=x,j;
while(i!=pre[i])
{
j=pre[i];
pre[i]=r;
i=j;
}
return r;
}
void join(int x,int y)
{
int fx=find(x);
int fy=find(y);
if(fx!=fy)
pre[fx]=fy;
}
int main(int argc, char const *argv[])
{
int n,m;
init();
int a[Max_N],v[Max_N];
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
for(int i=1;i<=n;i++)
scanf("%d",&v[i]);
for(int i=1;i<=n;i++)
{
join(i,v[i]);
}
ll res[Max_N];
memset(res,0,sizeof(res));
for(int i=1;i<=n;i++)
{
int mid=find(i);
res[mid]+=a[i];
}
sort(res,res+Max_N,cmp);
ll result=0;
for(int i=0;i<m;i++)
result+=res[i];
printf("%lld\n",result);
return 0;
}
总结
准确理解题目意思,精准的转换为已知的数学模型。