题面
样例
10 3
2 3 2 4 5 1 2 5 6 2
2 3 4 5 3 7 6 6 7 7
30
分析
(p.s.在这篇文章中,最大链指所有儿子到根节点的路径中最大的那一条)
易想到缩点。
发现缩点后连反边是一棵树(其实是森林)。
于是问题就转化为,给你一棵树,从中找 k k k 个叶子|使他们到根节点的路径|所有节点的并集|的权值和最大,这样每次找权值和最大的那一条即可,找到后直接向上爬每个元素更新,再加个线段树+优先队列,可以做到 O ( k l o g 2 2 ( n ) ) \mathcal {O}(klog_2^2(n)) O(klog22(n))。 (口胡)
想知道更优秀的做法,询问了一下 WJC,发现我们可以预处理出以某一个节点为根的树中的最大链为多少,把森林里的树都以最大链的长度为关键字丢进优先队列里,在找到一个最大链之后,删掉这个链,可以把这棵树分成几棵小树,再放进优先队列里。由于每个节点只会入队一次,出队一次,把大树分成几棵树的时候所有边和点都只会遍历一次,因此这样可以做到 O ( k l o g 2 ( n ) ) \mathcal {O}(klog_2(n)) O(klog2(n))。
Code
//*
#include <cstdio>
#include <algorithm>
#include <cmath>
#include <cstring>
#include <climits>
#include <stack>
#include <vector>
#include <queue>
using namespace std;
const int MAXN = 1e6 + 5, MAXM = 1e6 + 5;
struct Node {
int tp, dn, maxx;
bool operator < (const Node P) const { return maxx < P.maxx; }
Node() {}
Node(int x, int y, int z) { tp = x; dn = y; maxx = z; }
};
int n, k, a[MAXN], col[MAXN], nk, dfn[MAXN], low[MAXN], cnt, val[MAXN];
int Head[MAXN], Ver[MAXM << 1], Next[MAXM << 1], tot, d[MAXN];
int X[MAXN], Y[MAXN], dp[MAXN], num[MAXN], fa[MAXN], ans;
stack <int> s;
vector <int> v[MAXN];
priority_queue <Node> que;
void add(int x, int y) { Ver[++ tot] = y; Next[tot] = Head[x]; Head[x] = tot; }
int Min(int x, int y) { return x < y ? x : y; }
int Max(int x, int y) { return x > y ? x : y; }
void dfs(int x) {
dfn[x] = low[x] = ++ cnt; s.push(x);
for(int i = Head[x]; i; i = Next[i]) {
int Y = Ver[i];
if(!dfn[Y]) {
dfs(Y); low[x] = Min(low[x], low[Y]);
}
else if(!col[Y]) low[x] = Min(low[x], dfn[Y]);
}
if(low[x] == dfn[x]) {
int t; nk ++;
do {
t = s.top(); s.pop(); col[t] = nk;
} while(t != x);
}
}
void dfs1(int x) {
dp[x] = val[x]; num[x] = x;
for(unsigned int i = 0; i < v[x].size(); i ++) {
int Y = v[x][i]; dfs1(Y); fa[Y] = x;
if(dp[Y] + val[x] >= dp[x]) dp[x] = dp[Y] + val[x], num[x] = num[Y];
}
}
void read(int &x) {
x = 0; int f = 1; char c = getchar();
for(; c < '0' || c > '9'; c = getchar()) if(c == '-') f = 0;
for(; c >= '0' && c <= '9'; c = getchar()) x = (x << 1) + (x << 3) + (c ^ 48);
x = f ? x : -x;
}
int main() {
int x;
read(n); read(k);
for(int i = 1; i <= n; i ++) read(a[i]);
for(int i = 1; i <= n; i ++) {
read(x); add(i, x); X[i] = i; Y[i] = x;
}
for(int i = 1; i <= n; i ++) if(!dfn[i]) dfs(i);
for(int i = 1; i <= n; i ++) val[col[i]] += a[i];
for(int i = 1; i <= n; i ++) if(col[X[i]] != col[Y[i]]) v[col[Y[i]]].push_back(col[X[i]]), d[col[X[i]]] ++;
for(int i = 1; i <= nk; i ++) if(!d[i]) dfs1(i), que.push(Node(i, num[i], dp[i]));
for(int i = 1; i <= k; i ++) {
if(que.empty()) break;
Node t = que.top(); que.pop(); ans += t.maxx;
int Last = -1;
while(1) {
for(unsigned int j = 0; j < v[t.dn].size(); j ++) {
int Y = v[t.dn][j];
if(Y == Last) continue;
que.push(Node(Y, num[Y], dp[Y]));
}
Last = t.dn;
if(t.dn == t.tp) break; t.dn = fa[t.dn];
}
}
printf("%d", ans);
return 0;
}