题目连接
题意:给一颗带边权的树,你要求一个最大权联通快,要求里面超过k度的节点只能有一个。
题解:
这个题解不是官方思路。而是一位大佬的思路。
首先,我们dp[i][2]表示第i个节点为根的子树。有没有与fa向连的最大权值。我们不妨先考虑,没有大于k度的节点,然后再对于每一个节点,令它的度数大于k,从而求出最大权。
显然对于dp[i][0],他最多可以连接k个孩子,对于dp[i][1],他最多可以连接k-1个孩子。我们首先要择优选取。我们首先递归孩子,得出孩子节点的最大权值,再以(孩子的最大权值+这条边的权值)从大到小排序。dp[i][0] 等于前k个孩子和边权之和。dp[i][1]即为前k-1个。然后我们思考枚举其他的边。这里从上往下枚举。累计更换某个边后的权值之和。变更边时,应该从已选边里找一个最少的,然后加新边。递归结束后,我们让这个节点为度数大于k的点,所以我们直接吧他所有的孩子再累计上即可。
#include<cstdio>
#include<iostream>
#include<cstring>
#include <map>
#include <queue>
#include <set>
#include <cstdlib>
#include <cmath>
#include <algorithm>
#include <vector>
#include <string>
#include <list>
#include <bitset>
#include <array>
#include <cctype>
#include <time.h>
#pragma GCC optimize(2)
void read_f() { freopen("1.in", "r", stdin); freopen("1.out", "w", stdout); }
void fast_cin() { std::ios::sync_with_stdio(false); std::cin.tie(); }
void run_time() { std::cout << "ESC in : " << clock() * 1000.0 / CLOCKS_PER_SEC << "ms" << std::endl; }
template <typename T>
bool bacmp(const T & a, const T & b) { return a > b; }
template <typename T>
bool pecmp(const T & a, const T & b) { return a < b; }
#define ll long long
#define ull unsigned ll
#define _min(x, y) ((x)>(y)?(y):(x))
#define _max(x, y) ((x)>(y)?(x):(y))
#define max3(x, y, z) ( max( (x), max( (y), (z) ) ) )
#define min3(x, y, z) ( min( (x), min( (y), (z) ) ) )
#define pr(x, y) (make_pair((x), (y)))
#define pb(x) push_back(x);
using namespace std;
const int N = 5e5+5;
ll dp[N][2], ans;
int n, k;
vector< pair<ll, int> > g[N], v[N];
void init(int n)
{
for (int i = 1; i <= n; i++)
{
g[i].clear();
v[i].clear();
dp[i][0] = dp[i][1] = 0;
}
ans = 0;
}
void dfs(int x, int fa)
{
int sz = g[x].size();
for (auto i : g[x])
{
int y = i.second;
ll w = i.first;
if (y == fa) continue;
dfs(y, x);
v[x].pb( pr(dp[y][0] + w, y) );
}
sort(v[x].begin(), v[x].end(), bacmp<pair<ll, int> >);
for (int i = 0; i < min((int)v[x].size(), k-1 + (x==1) ); i++)
dp[x][0] += v[x][i].first;
for (int i = 0; i < min((int)v[x].size(), k); i++)
dp[x][1] += v[x][i].first;
}
void dfs2(int x, int fa, ll s)
{
for (int i = 0; i < v[x].size(); i++)
{
int y = v[x][i].second;
ll w = v[x][i].first;
if (x == 1)
{
if (i >= k)
dfs2(y, x, s+dp[x][0] - v[x][k-1].first + w - dp[y][0]);
else
dfs2(y, x, s+dp[x][0] - dp[y][0]);
}
else
{
if (i >= k)
dfs2(y, x, max(s+dp[x][0]-v[x][k-2].first + w -dp[y][0], dp[x][1] - v[x][k-1].first + w - dp[y][0]));
else if (i == k - 1)
dfs2(y,x, max(s + dp[x][0] - v[x][k-2].first + w - dp[y][0], dp[x][1]-dp[y][0]));
else
dfs2(y,x,max(s + dp[x][0]-dp[y][0], dp[x][1] - dp[y][0]));
}
}
ll sum = 0;
for (int i = 0; i < v[x].size(); i++)
sum += v[x][i].first;
ans = max(ans, sum+s);
}
int main()
{
int t; cin >> t;
while(t--)
{
scanf("%d%d", &n, &k);
init(n);
for (int i = 1; i < n; i++)
{
int x, y; ll w;
scanf("%d%d%lld", &x, &y, &w);
g[x].pb(pr(w, y));
g[y].pb(pr(w, x));
}
if (k == 0) { puts("0"); continue; }
else if (k == 1)
{
for (int i = 1; i <= n; i++)
{
ll sum = 0;
for (auto j : g[i])
sum += j.first;
ans = max(ans, sum);
}
printf("%lld\n", ans);
continue;
}
dfs(1, 0); dfs2(1, 0, 0);
printf("%lld\n", ans);
}
}