动态规划——Rerooting DP
Rerooting DP 方法是解决一类树上DP问题的一般解决方法,与树形DP不同的是,树形DP只根据节点的儿子节点计算得到父节点的答案,而在Rerooting DP法中,将每个节点都看成是该颗树的根节点,根节点在树中转移,因此叫Rerooting DP方法。
原文地址(日文):
Rerooting DP包括两次DFS:
- 计算子树答案(树形DP)。
- 根节点转移。
本文将以ABC222 F为例子,讲解Rerooting DP。
子树答案
定义: m [ u ] m[u] m[u]为从 u u u节点出发,只能到达子孙节点的的答案。可以很容易的通过树形DP的方法计算,因为节点的答案可通过儿子节点的答案计算得出。
void mdp(int u,int r){
for(int ne = head[u];ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
mdp(to,u);
m[u] = max(m[u], max(m[to], D[to]) + e[ne].cost);
}
}
根节点转移
根节点转移的方法一般格式为:
// u 为当前工作节点,r 为 u 的父节点,v 为 u-r 的贡献
void rdp(int u,int r,ll v);
如何让节点
u
u
u作为根节点以计算答案呢,我们发现,
u
u
u作为根节点向下的答案已经计算完成,即
m
[
u
]
m[u]
m[u]。唯一差的就是
(
u
,
r
)
(u,r)
(u,r)的答案。而上述方法中的参数
v
v
v就是
(
u
,
r
)
(u,r)
(u,r)的答案。
得到了 v v v之后,我们就有了 u u u的完整答案。
ans[u] = max(m[u],v);
考虑递归计算 u u u子节点的答案,如何计算递归方程中的 v v v。
我们发现, u u u的子节点 s s s除了 s s s自己之外,其他的兄弟之间取max即可(包括 ( u , r ) (u,r) (u,r))。
考虑如何计算剔除一个节点的最大值,维护一个前缀最大和后缀最大即可。
void rdp(int u,int r,ll v){
ans[u] = max(m[u],v);
vector<ll> subans;
for(int ne = head[u];ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
ll c = max(m[to],D[to]) + e[ne].cost;
subans.push_back(c);
}
vector<ll> pmax(subans.size()),smax(subans.size());
for(int i = 1;i < subans.size();i++)
pmax[i] = max(pmax[i - 1],subans[i - 1]);
for(int i = int(subans.size()) - 2;i >= 0;i--)
smax[i] = max(smax[i + 1],subans[i + 1]);
for(int ne = head[u],ptr = 0;ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
rdp(to,u,e[ne].cost + max(max(D[u],v),max(pmax[ptr],smax[ptr])));
ptr++;
}
}
这样,这个题就可以通过Rerooting DP的方法解决了。
struct Edge{
int to;
int nxt;
ll cost;
} e[400005];
int head[200005];
ll D[200005];
ll m[200005];
ll ans[200005];
int tot;
void add(int u,int v,ll cost){
tot++;
e[tot].to = v;
e[tot].nxt = head[u];
e[tot].cost = cost;
head[u] = tot;
}
void mdp(int u,int r){
for(int ne = head[u];ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
mdp(to,u);
m[u] = max(m[u], max(m[to], D[to]) + e[ne].cost);
}
}
void rdp(int u,int r,ll v){
ans[u] = max(m[u],v);
vector<ll> subans;
for(int ne = head[u];ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
ll c = max(m[to],D[to]) + e[ne].cost;
subans.push_back(c);
}
vector<ll> pmax(subans.size()),smax(subans.size());
for(int i = 1;i < subans.size();i++)
pmax[i] = max(pmax[i - 1],subans[i - 1]);
for(int i = int(subans.size()) - 2;i >= 0;i--)
smax[i] = max(smax[i + 1],subans[i + 1]);
for(int ne = head[u],ptr = 0;ne;ne = e[ne].nxt){
int to = e[ne].to;
if(to == r) continue;
rdp(to,u,e[ne].cost + max(max(D[u],v),max(pmax[ptr],smax[ptr])));
ptr++;
}
}
int main(){
FR;
int n;
cin >> n;
for(int i = 0;i < n - 1;i++){
int u,v; ll c;
cin >> u >> v >> c;
add(u,v,c);
add(v,u,c);
}
for(int i = 1;i <= n;i++){
cin >> D[i];
}
mdp(1,1);
rdp(1,1,0);
for(int i = 1;i <= n;i++){
cout << ans[i] << endl;
}
return 0;
}
例题
关于树上最大匹配问题,考虑Rerooting DP。
struct Edge
{
int to;
int nxt;
} e[400005];
int head[200005];
int tot = 0;
void add(int u, int v)
{
tot++;
e[tot].to = v;
e[tot].nxt = head[u];
head[u] = tot;
}
int dp[200005][2];
void tdp(int u, int r)
{
int sum = 0;
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == r)
continue;
tdp(to, u);
sum += max(dp[to][0], dp[to][1]);
}
dp[u][0] = sum;
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == r)
continue;
dp[u][1] = max(dp[u][1], sum - max(dp[to][0], dp[to][1]) + dp[to][0] + 1);
}
}
int cnt = 0;
void rdp(int u, int r, int j, int k)
{
int sum = dp[u][0];
int mx = 0, tmx = 0;
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == r)
continue;
int tag = sum - max(dp[to][0], dp[to][1]) + dp[to][0] + 1;
if (tag >= mx)
{
tmx = mx;
mx = tag;
}
else if (tag >= tmx)
{
tmx = tag;
}
}
for (int ne = head[u]; ne; ne = e[ne].nxt)
{
int to = e[ne].to;
if (to == r)
continue;
int tag = sum - max(dp[to][0], dp[to][1]) + dp[to][0] + 1;
rdp(to, u, sum - max(dp[to][0], dp[to][1]) + max(j, k), max((tag == mx ? tmx : mx) - max(dp[to][0], dp[to][1]) + max(j, k), r != 0 ? (j + 1 + sum - max(dp[to][0], dp[to][1])) : 0));
}
int ans = sum + max(j, k);
if (ans == max(dp[1][1], dp[1][0]))
cnt++;
}
void solve()
{
int n;
cin >> n;
for (int i = 0; i < n - 1; i++)
{
int u, v;
cin >> u >> v;
add(u, v);
add(v, u);
}
tdp(1, 1);
rdp(1, 0, 0, 0);
cout << cnt;
}
int main()
{
FR;
solve();
return 0;
}