题意:给出一棵有n个结点的树,树上有m条链,每条链有一个权值,选出最多没有重复节点的链,使得权值和最大。
思路:树形DP+树链剖分+LCA。首先处理出每条链两个端点的LCA并保存起来。
用dp[i]表示以i为根节点的子树的最大值,sum[i]表示i的儿子节点的dp值的和,即sigma(dp[j]),j是i的儿子节点。
这样状态转移方程也可以得出,当i节点不放链的话,dp[i] = sum[i]
当i节点放链的时候,观察树的特点(涂黑的结点是树链上的点),如下图所示:
所以我们可以dfs的过程中处理出每个节点的sum[i]-dp[i]的值,然后查询的时候在树状数组里查询即可。
#include<cstdio>
#include<cstring>
#include<cmath>
#include<cstdlib>
#include<iostream>
#include<algorithm>
#include<vector>
#include<map>
#include<queue>
#include<stack>
#include<string>
#include<map>
#include<set>
#include<ctime>
#define eps 1e-6
#define LL long long
#define pii pair<int, int>
#pragma comment(linker, "/STACK:1024000000,1024000000")
using namespace std;
const int MAXN = 100100;
//const int INF = 0x3f3f3f3f;
int n, m;
int pnt[MAXN];
bool vis[MAXN];
vector<int> G[MAXN], query[MAXN], W[MAXN];
struct Node {
int u, v, w;
Node(int u=0, int v=0, int w=0) : u(u), v(v), w(w) {}
};
vector<Node> chain[MAXN];
int Find(int x) {
if(x == pnt[x]) return x;
return pnt[x] = Find(pnt[x]);
}
void Tarjan(int cur) {
pnt[cur] = cur;
vis[cur] = 1;
for(int i = 0; i < G[cur].size(); i++) {
int v = G[cur][i];
if(vis[v]) continue;
Tarjan(v);
pnt[v] = cur;
}
for(int i = 0; i < query[cur].size(); i++) {
int v = query[cur][i];
if(vis[v]) {
int lca = Find(v);
chain[lca].push_back(Node(cur, v, W[cur][i]));
}
}
}
//以上是lca部分,以下是树链剖分
int C[MAXN];
int tot;
int siz[MAXN], son[MAXN], dep[MAXN], top[MAXN], fa[MAXN], pos[MAXN];
int lowbit(int x) {
return x & -x;
}
void add(int x, int d) {
while(x <= n) {
C[x] += d;
x += lowbit(x);
}
}
int Sum(int x) {
int ret = 0;
while(x > 0) {
ret += C[x];
x -= lowbit(x);
}
return ret;
}
void dfs(int cur, int f) {
siz[cur] = 1;
int tmp = 0;
for(int i = 0; i < G[cur].size(); i++) {
int u = G[cur][i];
if(u == f) continue;
dep[u] = dep[cur] + 1;
fa[u] = cur;
dfs(u, cur);
siz[cur] += siz[u];
if(siz[u] > tmp) son[cur] = u, tmp = siz[u];
}
}
void dfs2(int cur, int tp) {
top[cur] = tp;
pos[cur] = ++tot;
if(son[cur]) dfs2(son[cur], tp);
for(int i = 0; i < G[cur].size(); i++) {
int u = G[cur][i];
if(u==son[cur] || u==fa[cur]) continue;
dfs2(u, u);
}
}
int cal(int u, int v) {
int ans = 0;
int fu = top[u];
while(dep[fu] > dep[v]) {
ans += Sum(pos[u])-Sum(pos[fu]-1);
u = fa[fu]; fu = top[u];
}
if(dep[u] > dep[v]) ans += Sum(pos[u]) - Sum(pos[v]);
return ans;
}
//以下是树形DP
int sumv[MAXN], dp[MAXN];
void dfs3(int cur, int fa) {
sumv[cur] = 0;
for(int i = 0; i < G[cur].size(); i++) {
int u = G[cur][i];
if(u == fa) continue;
dfs3(u, cur);
sumv[cur] += dp[u];
}
dp[cur] = sumv[cur];
for(int i = 0; i < chain[cur].size(); i++) {
Node t = chain[cur][i];
int tmp = cal(t.u, cur) + cal(t.v, cur) + sumv[cur] + t.w;
dp[cur] = max(dp[cur], tmp);
}
add(pos[cur], sumv[cur]-dp[cur]);
}
void init() {
memset(vis, 0, sizeof(vis));
memset(C, 0, sizeof(C));
memset(son, 0, sizeof(son));
tot = 0;
for(int i = 1; i <= n; i++) {
G[i].clear();
W[i].clear();
query[i].clear();
chain[i].clear();
}
}
int main() {
//freopen("input.txt", "r", stdin);
int T; cin >> T;
while(T--) {
scanf("%d%d", &n, &m);
init();
for(int i = 1, u, v; i < n; i++) {
scanf("%d%d", &u, &v);
G[u].push_back(v);
G[v].push_back(u);
}
for(int i = 1, u, v, d; i <= m; i++) {
scanf("%d%d%d", &u, &v, &d);
query[u].push_back(v);
query[v].push_back(u);
W[u].push_back(d);
W[v].push_back(d);
}
Tarjan(1);
dfs(1, 0);
dfs2(1, 1);
dfs3(1, 0);
cout << dp[1] << endl;
}
return 0;
}