魔法树 / Magic Tree
题目链接:jzoj 7202 / luogu CF1193B / luogu P6847
题目大意
给你一棵树,每个点可能有至多一个果实,然后果实有成熟时间和价值。
你可以选一些果实,满足如果某个果实 A 的位置是另一个果实 B 的位置的祖先,那么 A 的成熟时间不能早于 B 的。
要你最大化并输出选的果实的价值和。
思路
首先不难想到树形 DP。
设
f
i
,
j
f_{i,j}
fi,j 为搞定
i
i
i 的子树,最后采摘时刻为
j
j
j 的最大价值和。
然后不难想到
n
2
n^2
n2 的转移,就是合并两个子树
u
,
v
u,v
u,v。
若两个子树分别是
f
1
f_1
f1,则
f
i
,
j
=
max
k
=
1
j
{
max
{
f
u
j
+
f
v
i
,
f
u
i
+
f
v
j
}
}
f_{i,j}=\max\limits_{k=1}^j\{\max\{f_{u_j}+f_{v_i},f_{u_i}+f_{v_j}\}\}
fi,j=k=1maxj{max{fuj+fvi,fui+fvj}}
然后发现你这个
max
\max
max 可以用线段树来搞,第二维搞一个线段树,就有了
n
n
n 个线段树。
然后你就发现变成了你要把两个线段树的值合并在一起。
可以用启发式合并达到
n
log
2
n
n\log^2n
nlog2n,但是我们这里用的是线段树合并。
普通的线段树合并当然不行,我们就要变形一下。
我们先合并左子树,再合并右子树。
然后我们分别记录现在两个线段树到当前的处理位置左端的最大前缀,然后合并一下就行了。
代码
#include<cstdio>
#include<iostream>
#define ll long long
#define INF 0x3f3f3f3f3f3f3f3f
using namespace std;
struct node {
int to, nxt;
}e[100001];
int n, m, k, fa[100001];
int a[100001], b[100001];
int x, y, z, le[100001], KK;
int rt[100001], tot;
struct Tree {
int l, r;
ll val, lazy;
}t[100001 << 6];
int get_new() {
return ++tot;
}
void add(int x, int y) {
e[++KK] = (node){y, le[x]}; le[x] = KK;
}
void down(int now) {
if (t[now].lazy) {
if (t[now].l) t[t[now].l].lazy += t[now].lazy, t[t[now].l].val += t[now].lazy;
if (t[now].r) t[t[now].r].lazy += t[now].lazy, t[t[now].r].val += t[now].lazy;
t[now].lazy = 0;
}
}
void merge(int &x, int y, ll lsum, ll rsum, int l, int r) {
if (!x && !y) return ;
if (!x) {
t[y].lazy += rsum;
t[y].val += rsum;
x = y;
return ;
}
if (!y) {
t[x].lazy += lsum;
t[x].val += lsum;
return ;
}
if (l == r) {
lsum = max(lsum, t[y].val);
rsum = max(rsum, t[x].val);
t[x].val = max(t[x].val + lsum, t[y].val + rsum);
return ;
}
down(x); down(y);
int mid = (l + r) >> 1;
ll llsum = t[t[x].l].val;//记录左边的全部值
ll rrsum = t[t[y].l].val;
merge(t[x].l, t[y].l, lsum, rsum, l, mid);//先合并左边的
merge(t[x].r, t[y].r, max(lsum, rrsum), max(rsum, llsum), mid + 1, r);//然后再合并右边,记得要先算上左边的值
t[x].val = max(t[t[x].l].val, t[t[x].r].val);
}
void insert(int pl, ll val, int &now, int l, int r) {
if (!now) now = get_new();
t[now].val = max(t[now].val, val);
if (l == r) return ;
down(now);
int mid = (l + r) >> 1;
if (pl <= mid) insert(pl, val, t[now].l, l, mid);
else insert(pl, val, t[now].r, mid + 1, r);
}
ll query(int L, int R, int now, int l, int r) {
if (!now) return 0;
if (L <= l && r <= R) return t[now].val;
down(now);
int mid = (l + r) >> 1;
ll re = -INF;
if (L <= mid) re = max(re, query(L, R, t[now].l, l, mid));
if (mid < R) re = max(re, query(L, R, t[now].r, mid + 1, r));
return re;
}
void dfs(int now) {
for (int i = le[now]; i; i = e[i].nxt) {
dfs(e[i].to);
merge(rt[now], rt[e[i].to], 0, 0, 1, k);
}
insert(a[now], b[now] + max(0ll, query(1, a[now], rt[now], 1, k)), rt[now], 1, k);
}
int main() {
scanf("%d %d %d", &n, &m, &k);
for (int i = 2; i <= n; i++) scanf("%d", &fa[i]), add(fa[i], i);
for (int i = 1; i <= m; i++) {
scanf("%d %d %d", &x, &y, &z);
a[x] = y; b[x] = z;
}
dfs(1);
printf("%lld", t[rt[1]].val);
return 0;
}