换根解决的是“不定根”的树形dp问题。该类题目的特点是:给定一个树形结构,需要以每个节点为根进行一系列统计。
方法为两次扫描来求解:
第一次扫描时,任选一个点为根,在“有根树”上执行一次树形dp,在回溯时,自底向上的状态转移。
第二次扫描时,从第一次选的根出发,对整根树执行一个dfs,在每次递归前进行自顶向下的转移,计算出换根后的解。
题目:
有一个树形的水系,由 N-1 条河道和 N 个交叉点组成。
我们可以把交叉点看作树中的节点,编号为 1~N,河道则看作树中的无向边。
每条河道都有一个容量,连接 x 与 y 的河道的容量记为 c(x,y)。
河道中单位时间流过的水量不能超过河道的容量。
有一个节点是整个水系的发源地,可以源源不断地流出水,我们称之为源点。
除了源点之外,树中所有度数为 1 的节点都是入海口,可以吸收无限多的水,我们称之为汇点。
也就是说,水系中的水从源点出发,沿着每条河道,最终流向各个汇点。
在整个水系稳定时,每条河道中的水都以单位时间固定的水量流向固定的方向。
除源点和汇点之外,其余各点不贮存水,也就是流入该点的河道水量之和等于从该点流出的河道水量之和。
整个水系的流量就定义为源点单位时间发出的水量。
在流量不超过河道容量的前提下,求哪个点作为源点时,整个水系的流量最大,输出这个最大值。
思路
第一步:我们任选一个点root为根,d[i]表示以i为根的子树的最大流量,那么d[i]就可以从它的儿子节点转移过来。如果儿子节点t不是叶子节点,那么d[i] += min(d[t],c(i,t)),否则d[i] += c(i,t);
第二步:f[i]表示以i为根的树的最大流量,从第一步我们求出了以某个点为根节点的子树的最大流量,我们考虑是否能O(1)通过d数组来转移f数组呢?显然刚才的f[root] = d[root],考虑它的子树t的转移。f[t]包含了两部分:一部分为从t流向以t为根的子树的流量,为d[t];另一部分为t沿着父节点x的河道,流向别的地方的流量。那么f[x] - min(d[t],c(x,t))就是x流向除t以外的流量,那么min(f[x] - min(d[t],c(x,t)),c(x,t))就是t沿着父节点x的河道,流向别的地方的流量,对于叶子节点判断一下即可。
#include <iostream>
#include <vector>
#include <cstring>
using namespace std;
struct node{
int num,val;
node(int a,int b)
{
num = a;
val = b;
}
};
int deg[200005];
vector<node> g[200005];
int d[200005];
int f[200005];
int vis[200005];
void dfs1(int x)
{
vis[x] = 1;
for (int i = 0; i < g[x].size(); i++)
{
int t = g[x][i].num;
if( vis[t] ) continue;
dfs1(t);
if( deg[t] == 1 ) d[x] += g[x][i].val;
else d[x] += min(d[t],g[x][i].val);
}
}
void dfs2(int x)
{
vis[x] = 1;
for (int i = 0; i < g[x].size(); i++)
{
int t = g[x][i].num;
if( vis[t] ) continue;
if( deg[x] == 1 ) f[t] = d[t] + g[x][i].val;
else f[t] = d[t] + min(f[x] - min(d[t],g[x][i].val),g[x][i].val);
dfs2(t);
}
}
int main()
{
int t;
cin >> t;
while( t-- )
{
memset(deg,0,sizeof(deg));
memset(d,0,sizeof(d));
memset(f,0,sizeof(f));
memset(vis,0,sizeof(vis));
int n;
cin >> n;
for (int i = 1; i <= n; i++) g[i].clear();
for (int i = 1; i < n; i++)
{
int x,y,v;
cin >> x >> y >> v;
g[x].push_back(node(y,v));
g[y].push_back(node(x,v));
deg[x] ++;
deg[y] ++;
}
dfs1(1);
memset(vis,0,sizeof(vis));
f[1] = d[1];
dfs2(1);
int ans = 0;
for (int i = 1; i <= n; i++)
{
ans = max(ans,f[i]);
}
cout << ans << endl;
}
return 0;
}