思路;
直接缩点然后贪心走,注意细节
c o d e code code
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<queue>
using namespace std;
long long n, m, tot, cnt, tmp, tot1, k, tot2;
long long stack[1001000], top, ru[1001000], bz[1001000];
long long a[1001000], head[1001000], head1[1000100], head2[1000100];
long long dfn[1001000], low[1001000];
long long d[1001000], c[1001000], siz[1001000];
bool v[1001000];
long long fa[1010000];
struct node
{
long long from, to, next;
}b[1001000], e[1001000], e1[2001000];
struct abc
{
long long an, id;
}ans[1000100];
void add(long long x, long long y)
{
b[++tot]=(node){x, y, head[x]};
head[x]=tot;
}
void add1(long long x, long long y)
{
e[++tot1]=(node){x, y, head1[x]};
head1[x]=tot1;
}
void add2(long long x, long long y)
{
e1[++tot2]=(node){x, y, head2[x]};
head2[x]=tot2;
e1[++tot2]=(node){y, x, head2[y]};
head2[y]=tot2;
}
void tarjan(long long x)
{
dfn[x]=low[x]=++cnt;
v[x]=1;
stack[++top]=x;
for(long long i=head[x]; i; i=b[i].next)
{
long long y=b[i].to;
if(!dfn[y])
{
tarjan(y);
low[x]=min(low[x], low[y]);
}
else if(v[y])
low[x]=min(low[x], dfn[y]);
}
if(dfn[x]==low[x])
{
tmp++;
do
{
c[stack[top]]=tmp;
//cout<<c[stack[top]]<<' '<<stack[top]<<endl;
d[tmp]+=a[stack[top]];
v[stack[top]]=0;
top--;
}
while(x!=stack[top+1]);
}
return;
}
long long fi(long long x)
{
if(fa[x]==x)
return x;
return fa[x]=fi(fa[x]);
}
long long dfs(long long x)
{
if(ans[x].an)
return ans[x].an;
long long suma=0;
for(long long i=head1[x]; i; i=e[i].next)
{
long long y=e[i].to, fx=fi(x), fy=fi(y);
if(siz[fx]>siz[fy])
siz[fx]+=siz[fy], fa[fy]=fx;
else
siz[fy]+=siz[fx], fa[fx]=fy;
suma=max(suma, dfs(y));
}
return ans[x].an=suma+d[x];
}
/*
void dfs1(long long x, long long flag)
{
fa[x]=flag;
cout<<x<<endl;
for(long long i=head2[x]; i; i=e1[i].next)
{
long long y=e1[i].to;
if(fa[y]==0)
dfs1(y, flag);
}
}
*/
bool cmp(abc x, abc y)
{
return x.an<y.an;
}
int main()
{
// freopen("azeroth.in", "r", stdin);
// freopen("azeroth.out", "w", stdout);
scanf("%lld%lld", &n, &m);
for(long long i=1; i<=m; i++)
{
long long x, y;
scanf("%lld%lld", &x, &y);
add(x, y);
}
for(long long i=1; i<=n; i++)
scanf("%lld", &a[i]);
scanf("%lld", &k);
for(long long i=1; i<=n; i++)
{
if(!dfn[i])
tarjan(i);
}
for(long long i=1; i<=n; i++)
{
for(long long j=head[i]; j; j=b[j].next)
{
long long y=b[j].to;
if(c[i]!=c[y])
add1(c[i], c[y]), ru[c[y]]++, add2(c[i], c[y]);
}
}
for(long long i=1; i<=tmp; i++)
fa[i]=i, siz[i]=1;
//dfs_(100, 0);
for(long long i=1; i<=tmp; i++)
{
if(ru[i]==0)
dfs(i);
}
for(long long i=1; i<=tmp; i++)
ans[fi(i)].an=max(ans[fi(i)].an, ans[i].an), ans[i].id=i;
sort(ans+1, ans+1+tmp, cmp);
long long aans=0;
memset(v, 0, sizeof(v));
long long i=tmp;
k++;
while(i>=1&&k!=0)
{
if(v[fi(fa[ans[i].id])]==0)
aans+=ans[i].an, v[fi(fa[ans[i].id])]=1, i--, k--;
else
i--;
}
printf("%lld", aans);
return 0;
}