Time Limit: 10 Sec Memory Limit: 256 MB
Description
有一棵点数为n的树,树边有边权。给你一个在0~n之内的正整数m,你要在这棵树中选择m个点,将其染成黑色,并
将其他的n-m个点染成白色。将所有点染色后,你会获得黑点两两之间的距离加上白点两两之间距离的和的收益。
问收益最大值是多少。
Input
第一行两个整数n,m。
接下来n-1行每行三个正整数fr,to,len,表示该树中存在一条长度为len的边(fr,to)。
输入保证所有点之间是联通的。
Output
输出一个正整数,表示收益的最大值。
Sample Input
5 2
1 2 3
1 5 1
2 3 1
2 4 2
Sample Output
17
Sample Explanation
将点1,2染黑就能获得最大收益。
Source
鸣谢bhiaibogf提供
Data Range
30%的数据:n <= 15。
60%的数据:n <= 100。
另20%的数据:这棵树是一条链。
100%的数据:0 <= m <= n <= 2000, 1 <= len <= 1000000。
【原题地址】
【30分】 暴搜 + 倍增求LCA
- 看到“30%的数据:n <= 15。”并且还是黑白点染色,很容易想到 O(215) 的暴搜
- 统一答案时我们要计算树上任意两点
x,y
间的距离,也很容易想到倍增求
LCA
,记
f[x]
表示从根节点到点
x
的距离,
x,y 的最近公共祖先为 z ,则两点距离为f[x]+f[y]−2×f[z]
【代码1】
#include <iostream>
#include <cstdio>
using namespace std;
const int N = 2005, M = 4005;
int lst[N], to[M], nxt[M], cst[M], dep[N], flw[N], fa[N][15];
int n, m, T, Ans, vis[N];
inline void addEdge(const int &x, const int &y, const int &z)
{
nxt[++T] = lst[x]; lst[x] = T; to[T] = y; cst[T] = z;
nxt[++T] = lst[y]; lst[y] = T; to[T] = x; cst[T] = z;
}
inline void Inite(const int &x, const int &fat)
{
dep[x] = dep[fat] + 1;
for (int i = 0; i < 12; ++i)
fa[x][i + 1] = fa[fa[x][i]][i];
for (int i = lst[x]; i; i = nxt[i])
{
int y = to[i];
if (y == fat) continue;
fa[y][0] = x;
flw[y] = flw[x] + cst[i];
Inite(y, x);
}
}
inline void Swap(int &a, int &b) {int t = a; a = b; b = t;}
inline int Query(int x, int y)
{
if (dep[x] < dep[y]) Swap(x, y);
for (int i = 12; i >= 0; --i)
{
if (dep[fa[x][i]] >= dep[y]) x = fa[x][i];
if (x == y) return x;
}
for (int i = 12; i >= 0; --i)
if (fa[x][i] != fa[y][i]) x = fa[x][i], y = fa[y][i];
return fa[x][0];
}
inline int CkDist(const int &x, const int &y)
{
int z = Query(x, y);
return flw[x] + flw[y] - (flw[z] << 1);
}
inline void Dfs(const int &k, const int &cnt)
{
if (cnt > m || n - k + 1 + cnt < m) return ;
if (k > n)
{
int res = 0;
if (cnt == m)
{
for (int i = 1; i <= n; ++i)
for (int j = i + 1; j <= n; ++j)
if (vis[i] == vis[j]) res += CkDist(i, j);
if (Ans < res) Ans = res;
}
return ;
}
for (int i = 0; i <= 1; ++i)
{
vis[k] = i;
Dfs(k + 1, cnt + i);
vis[k] = 0;
}
}
int main()
{
scanf("%d%d", &n, &m); int x, y, z;
for (int i = 1; i < n; ++i)
{
scanf("%d%d%d", &x, &y, &z);
addEdge(x, y, z);
}
Inite(1, 0);
Dfs(1, 0); printf("%d\n", Ans);
fclose(stdin); fclose(stdout);
return 0;
}
【100分】 树形DP
- 记
f[x][j]
表示以点
x
为根,选择
j 个黑点,对答案的最大贡献 - 为什么是说“对答案的最大贡献呢”?首先我们是从
x
的子节点
y 转移过来,那么主要是想到一点,即分别考虑每条边 x→y 对答案的贡献,也就是“边一侧的黑点数 × 另一侧的黑点数 × 边的长度 + 边一侧的白点数 × 另一侧的白点数 × 边的长度” - 我们记
sze[x]
表示以
x
为根的子树大小,边
x→y 对答案的贡献为 valx→y ,则状态转移方程为:f[x][j]=Max(f[x][j−k]+f[y][k]+valx→y)=Max(f[x][j−k]+f[y][k]+k×(j−k)×lenx→y+(sze[y]−k)×(n−m−sze[y]+k)×lenx→y) - 看起来我们要枚举
x→y,k,j
,复杂度为
O(n3)
,实际上我们会发现枚举的范围为
0≤k≤sze[y],0≤j−k≤sze[x]−sze[y]
,那么总的枚举次数就为
∑(sze[x]−sze[y])×sze[y]
,我们不妨把其看作每次在以某一点
x
为根的子树上任选两个点
a,b ,使得 a,b 的最近公共祖先为 x 的方案数,因此实际复杂度为O(n2)
【代码2】
#include <iostream>
#include <cstdio>
using namespace std;
typedef long long ll;
const int N = 2005;
int sze[N], n, m; ll f[N][N];
struct Edge
{
int to, len; Edge *nxt;
}p[N << 1], *T = p, *lst[N];
inline void addEdge(const int &x, const int &y, const int &z)
{
(++T)->nxt = lst[x]; lst[x] = T; T->to = y; T->len = z;
(++T)->nxt = lst[y]; lst[y] = T; T->to = x; T->len = z;
}
inline int Min(const int &x, const int &y) {return x < y ? x : y;}
inline void CkMax(ll &x, const ll &y) {if (x < y) x = y;}
inline void Dfs(const int &x, const int &fa)
{
int y; ll z; sze[x] = 1;
for (Edge *e = lst[x]; e; e = e->nxt)
{
if ((y = e->to) == fa) continue;
Dfs(y, x); sze[x] += sze[y];
}
for (int i = 2; i <= sze[x]; ++i) f[x][i] = -1ll;
f[x][0] = f[x][1] = 0ll;
for (Edge *e = lst[x]; e; e = e->nxt)
{
if ((y = e->to) == fa) continue; z = e->len;
int tx = Min(m, sze[x]);
for (int j = tx; j >= 0; --j)
{
int ty = Min(j, sze[y]);
for (int k = 0; k <= ty && j >= k; ++k)
//为什么这里j、k要分别正着和倒着枚举?
//回忆01背包:这里我们显然是不能重复选取相同的黑点数来转移的,这样会算重
//则我们要使j尽量大,优先转移来避免算重
//而当k = 0时:f[x][j - k = j]
//那么若不先处理k = 0,而之后转移的话,f[x][j]的值已经改变,我们同样会重复计算
if (f[x][j - k] != -1ll)
CkMax(f[x][j], f[x][j - k] + f[y][k]
+ (ll)k * (m - k) * z + (ll)(sze[y] - k) * (n - m - sze[y] + k) * z);
}
}
}
inline int get()
{
char ch; int res;
while ((ch = getchar()) < '0' || ch > '9');
res = ch - '0';
while ((ch = getchar()) >= '0' && ch <= '9')
res = (res << 3) + (res << 1) + ch - '0';
return res;
}
int main()
{
freopen("coloration.in", "r", stdin);
freopen("coloration.out", "w", stdout);
scanf("%d%d", &n, &m); int x, y, z;
for (int i = 1; i < n; ++i)
{
x = get(); y = get(); z = get();
addEdge(x, y, z);
}
Dfs(1, 0);
cout << f[1][m] << endl;
fclose(stdin); fclose(stdout);
return 0;
}