杭电多校3 1002
题意:
给定一棵
n
n
n节点的带有点权的树,求任意条不相交的链在树上的最大点权和。
题解:
存在两种情况,对于一个
L
C
A
LCA
LCA及其子树来说,链只存在两种情况,要么经过该点通向其两棵子树,要么不经过该点,仅存在于子树中。
设计
d
p
[
u
]
dp[u]
dp[u]表示该子树的解,
d
p
[
u
]
=
m
a
x
(
∑
d
p
[
v
]
,
s
u
m
[
u
]
+
W
[
v
1
]
+
W
[
v
2
]
−
D
P
[
v
1
]
−
D
P
[
v
2
]
)
dp[u] = max(\sum dp[v], sum[u] + W[v_1] + W[v_2] - DP[v_1] - DP[v_2])
dp[u]=max(∑dp[v],sum[u]+W[v1]+W[v2]−DP[v1]−DP[v2]),其中
W
W
W为子树中最大的两条链的点权和,
s
u
m
[
u
]
sum[u]
sum[u]表示
u
u
u子树内的
d
p
dp
dp总和,
D
P
[
u
]
DP[u]
DP[u]表示子树中最大的两条链的
D
P
DP
DP总和,这两条链可以通过线段树合并来维护,在
m
e
r
g
e
merge
merge操作时,每次取最大的
s
u
m
sum
sum,保证对于一个
L
C
A
LCA
LCA来说,其要合并的子树中的链必然是最大的,同时在叶子节点更新答案,维护最大值,在减去
D
P
DP
DP链和加上
W
W
W链时通过
p
u
s
h
_
d
o
w
n
push\_down
push_down实现区间修改即可。
#include<bits/stdc++.h>
using namespace std;
#define endl '\n'
#define int long long
#define debug(p) for (auto i : p)cerr << i << " "; cerr << endl;
#define debugs(p) for (auto i : p)cerr << i.first << " " << i.second << endl;
typedef pair<int, int> pll;
string yes = "YES";
string no = "NO";
constexpr int N = 2e5 + 7;
int c[N], w[N], dp[N], root[N];
vector<int>edge[N];
int idx, tot, n;
struct Seg_tree{
int s[2];
int sum, add;
void init()
{
s[0] = s[1] = 0;
sum = add = 0;
}
}tr[N << 5];
void push_up(int u)
{
tr[u].sum = tr[tr[u].s[0]].sum + tr[tr[u].s[1]].sum;
}
void push_down(int u, int l, int r)
{
if(tr[u].add)
{
int mid = l + r >> 1;
if(tr[u].s[0])
{
tr[tr[u].s[0]].add += tr[u].add;
// tr[tr[u].s[0]].sum += tr[u].add;
tr[tr[u].s[0]].sum += (mid - l + 1) * tr[u].add;
}
if(tr[u].s[1])
{
tr[tr[u].s[1]].add += tr[u].add;
// tr[tr[u].s[1]].sum += tr[u].add;
tr[tr[u].s[1]].sum += (r - (mid + 1) + 1) * tr[u].add;
}
tr[u].add = 0;
}
}
void modify(int &u, int l, int r, int pos, int x)
{
if(!u) u = ++idx;
// tr[u].sum = x;
if(l == r)
{
tr[u].sum = x;
return;
}
int mid = l + r >> 1;
if(pos <= mid)modify(tr[u].s[0], l, mid, pos, x);
else modify(tr[u].s[1], mid + 1, r, pos, x);
push_up(u);
}
int merge(int u, int v, int l, int r, int &ans)
{
if(!u || !v)return u + v;
if(l == r)ans = max(ans, tr[u].sum + tr[v].sum + tot);
push_down(u, l, r);
push_down(v, l, r);
int mid = l + r >> 1;
tr[u].sum = max(tr[u].sum, tr[v].sum);
tr[u].s[0] = merge(tr[u].s[0], tr[v].s[0], l, mid, ans);
tr[u].s[1] = merge(tr[u].s[1], tr[v].s[1], mid + 1, r, ans);
return u;
}
void dfs(int u, int f)
{
int sum = 0;
for (auto v : edge[u])
{
if(v == f)continue;
dfs(v, u);
sum += dp[v];
}
dp[u] = sum;
tot = sum;
modify(root[u], 1, n, c[u], w[u]);
for (auto v : edge[u])
{
if(v == f)continue;
tr[root[v]].sum -= dp[v];
tr[root[v]].add -= dp[v];
root[u] = merge(root[u], root[v], 1, n, dp[u]);
}
tr[root[u]].sum += tot;
tr[root[u]].add += tot;
}
void solve()
{
cin >> n;
for (int i = 1; i <= n; i++)cin >> c[i];
for (int i = 1; i <= n; i++)cin >> w[i];
for (int i = 1; i < n; i++)
{
int u, v;
cin >> u >> v;
edge[u].push_back(v);
edge[v].push_back(u);
}
dfs(1, -1);
cout << dp[1] << endl;
for (int i = 1; i <= max(n, idx); i++)
{
if(i <= n)
{
edge[i].clear();
root[i] = 0;
}
if(i <= idx)tr[i].init();
}
idx = tot = 0;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0), cout.tie(0);
int T = 1;
cin >> T;
while(T--)
{
solve();
}
}