题意:已知树上的若干链< u,v>,每条链有一个权值w,求链不相交的最大权值和集合。
解法:
1.以1为根建立有根树
2.寻找每点的dfs出入序
3.将链按照其端点lca的不同进行分类。dp[u]:以u为跟的子树链不相交最大权值和。sum[u]:u的子节点的dp之和。
(1).没有以u为lca的链。dp[u]=sum[u]
(2).有以u为lca的链。dp[u]=max(dp[u],sum[u]+Sum(sum[v]-dp[v])+quan[id] | v取遍该链上的所有点(不包括u本身),id为当前链的编号)
该公式仔细想一下就能得到,思想是先加后减
本题可得到一个实用的方法结论
1.动态改变树中点权,多次询问链
#pragma comment(linker,"/STACK:102400000,102400000")
#include<stdio.h>
#include<string.h>
#include<iostream>
#include<vector>
#include<math.h>
#include<algorithm>
#define ll int
using namespace std;
const int maxn = 100000+10;
int t,n,m;
vector<int> g[maxn];
int fa[maxn][25],deep[maxn];
int cnt,l[maxn],r[maxn];
vector<int> lis[maxn];
void dfs(int u,int f,int d){
l[u]=++cnt,deep[u]=d,fa[u][0]=f;
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==f) continue;
dfs(v,u,d+1);
}
r[u]=++cnt;
}
void init(){
cnt=0; dfs(1,0,1);
for(int i=1;i<=20;i++) for(int j=1;j<=n;j++) fa[j][i]=fa[fa[j][i-1]][i-1];
for(int i=1;i<=n;i++) lis[i].clear();
}
int lca(int x, int y) {
if (deep[x] < deep[y]) swap(x, y);
int delta = deep[x] - deep[y];
for(int j=0;j<=20;j++) if ((1<<j) & delta) x = fa[x][j];
if (x == y) return x;
for(int j=20;j>=0;j--) if (fa[x][j] != fa[y][j]) x = fa[x][j], y = fa[y][j];
return fa[x][0];
}
int st[maxn],en[maxn],quan[maxn];
int dp[maxn],sum[maxn],sd[maxn<<1],ss[maxn<<1];
int lowbit(int x) { return x&(-x); }
void Add(int id,int x,int c[]) { for(int i=id;i<=2*n;i+=lowbit(i)) c[i]+=x; }
int Sum(int id,int c[]) { int tmp=0; for(int i=id;i>0;i-=lowbit(i)) tmp+=c[i]; return tmp; }
void dfs2(int u,int f){
for(int i=0;i<g[u].size();i++){
int v=g[u][i];
if(v==f) continue;
dfs2(v,u);
sum[u]+=dp[v];
}
dp[u]=sum[u];
for(int i=0;i<lis[u].size();i++){
int id=lis[u][i],s=st[id],t=en[id],w=quan[id];
int tmp=Sum(l[s],ss)+Sum(l[t],ss)-Sum(l[s],sd)-Sum(l[t],sd)+sum[u];//Sum(l[s],ss)即为s到达root的sum函数和
//由于dfs的更新过程是从下往上的,所以在更新u节点时,u及u的祖先节点都是没有更新的,因此求和后只是s到t路径上的权值和
dp[u]=max(dp[u],tmp+w);
}
Add(l[u],dp[u],sd),Add(r[u],-dp[u],sd),Add(l[u],sum[u],ss),Add(r[u],-sum[u],ss);
}
int main(){
//freopen("a.txt","r",stdin);
scanf("%d",&t);
while(t--){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++) g[i].clear();
for(int i=1;i<n;i++) { int u,v; scanf("%d%d",&u,&v); g[u].push_back(v); g[v].push_back(u); }
init();
for(int i=1;i<=m;i++){
scanf("%d%d%d",&st[i],&en[i],&quan[i]);
int tmp=lca(st[i],en[i]); lis[tmp].push_back(i);
}
memset(dp,0,sizeof(dp)); memset(sum,0,sizeof(sum)); memset(sd,0,sizeof(sd)); memset(ss,0,sizeof(ss));
dfs2(1,0);
printf("%d\n",dp[1]);
}
return 0;
}