考试题的Path的升级版,不是求最多的路径数量,而是求选哪些路径的价值和最大,这样就没有办法贪心了。
知识点:
1、很明显的能看出来这是一道树形dp题。
2、每个点之间的关系要和路径连上关系,树上的路径是固定的,那就有一个lca。
3、为了计算dp值,要求路径上的点的值,又要更新,那就要用到线段树。
4、树上的线段树,又要用到dfs序。
然后分条来讲
1、树形dp这就没法讲了,看不出来也没有办法
2、树上的路径都有两个端点,又要求路径上的每个点的相关数据,那就要知道一个lca,方便计算
3、dp的定义是以i为根时权值最大为dp[i],还要一个sum[i],用来计算i的每个子节点的dp值的和,转移时dp[i]=max(sum[i],sum[i]+ALLsum[k]-ALLdp[k]),其中k是以i为lca的一条路径的节点。
这样讲肯定看不懂,可以这么理解。
现在有一条路径,如果要选这条路径,那么这条路径上的所有点不能作为其他路径上的点,那么就要收集路径上的sum和dp,用sum-dp就行了。dp是要减掉路径上的点的可能性。
4、dfs序在这里简单,只要记录一下L、R就行。线段树里装的是这个点到1点的路径上的sum值得和与dp值得和。那这棵子树上就是Query(x)+Query(y)-Query(lca)*2 就行了。
还要注意只有当所有子节点都算出来后才能算父节点。
#include<bits/stdc++.h>
#define M 100005
using namespace std;
struct node1{int x,y,lca,z;};
vector<node1>G[M];
vector<int>edge[M];
int fa[18][M],dep[M],n,m,dp[M],sum[M];
int R[M],L[M],T;
struct node2{int L,R,s,d;}tree[4*M];
void Build(int L,int R,int p){//以下的Build,Query,Change都是线段树,写的有点麻烦,大致是单点查询,区间更新
tree[p].L=L;tree[p].R=R;
tree[p].s=tree[p].d=0;
if(L==R){
return;
}
int mid=(L+R)/2;
Build(L,mid,p*2);Build(mid+1,R,p*2+1);
}
int Querys(int L,int R,int p){
if(tree[p].L==L&&tree[p].R==R)return tree[p].s;
int mid=(tree[p].L+tree[p].R)/2;
if(mid>=R)return Querys(L,R,2*p)+tree[p].s;
else if(mid<L)return Querys(L,R,2*p+1)+tree[p].s;
else return Querys(L,mid,2*p)+Querys(mid+1,R,2*p+1)+tree[p].s;
}
int Queryd(int L,int R,int p){
if(tree[p].L==L&&tree[p].R==R)return tree[p].d;
int mid=(tree[p].L+tree[p].R)/2;
if(mid>=R)return Queryd(L,R,2*p)+tree[p].d;
else if(mid<L)return Queryd(L,R,2*p+1)+tree[p].d;
else return Queryd(L,mid,2*p)+Queryd(mid+1,R,2*p+1)+tree[p].d;
}
void Change(int L,int R,int p,int sumk,int dpk){
if(tree[p].L==L&&tree[p].R==R){
tree[p].s+=sumk;tree[p].d+=dpk;
return;
}
int mid=(tree[p].L+tree[p].R)/2;
if(mid>=R)Change(L,R,2*p,sumk,dpk);
else if(mid<L)Change(L,R,2*p+1,sumk,dpk);
else Change(L,mid,2*p,sumk,dpk),Change(mid+1,R,2*p+1,sumk,dpk);
}
void f(int x,int fa1){//造树,dfs序
fa[0][x]=fa1;dep[x]=dep[fa1]+1;
L[x]=++T;
for(int i=0;i<(int)edge[x].size();i++){
int y=edge[x][i];
if(y==fa1)continue;
f(y,x);
}
R[x]=T;
}
void Init(){//LCA的预处理
for(int j=1;j<18;j++)
for(int i=1;i<=n;i++)
fa[j][i]=fa[j-1][fa[j-1][i]];
}
int LCA(int x,int y){//求两点LCA
if(dep[x]>dep[y])swap(x,y);
int step=dep[y]-dep[x];
for(int i=0;i<18;i++)
if(step&(1<<i))y=fa[i][y];
if(x==y)return x;
for(int i=17;i>=0;i--)
if(fa[i][x]!=fa[i][y])
x=fa[i][x],y=fa[i][y];
return fa[0][x];
}
void f1(int x,int fa1){
for(int i=0;i<(int)edge[x].size();i++){//先算完所有字点
int y=edge[x][i];
if(y==fa1)continue;
f1(y,x);
sum[x]+=dp[y];
}
dp[x]=sum[x];
for(int i=0;i<(int)G[x].size();i++){//用所有lca为x的路径来更新
int X=G[x][i].x,Y=G[x][i].y;
int sumk=Querys(L[X],L[X],1)+Querys(L[Y],L[Y],1);
int dpk=Queryd(L[X],L[X],1)+Queryd(L[Y],L[Y],1);
dp[x]=max(dp[x],sum[x]+G[x][i].z+sumk-dpk);
}
Change(L[x],R[x],1,sum[x],dp[x]);//算完更新
}
void solve(){
scanf("%d%d",&n,&m);
for(int i=1;i<n;i++){
int x,y;
scanf("%d%d",&x,&y);
edge[x].push_back(y);
edge[y].push_back(x);
}
f(1,0);Init();
for(int i=1;i<=m;i++){
int x,y,z;
scanf("%d%d%d",&x,&y,&z);
int lca=LCA(x,y);
G[lca].push_back((node1){x,y,lca,z});//按照lca放
}
Build(1,n,1);
f1(1,0);
printf("%d\n",dp[1]);
}
void Clear(){//多组数据,要清空
T=0;
memset(fa,0,sizeof fa);
memset(dep,0,sizeof dep);
memset(dp,0,sizeof dp);
memset(sum,0,sizeof sum);
memset(L,0,sizeof L);
memset(R,0,sizeof R);
for(int i=1;i<=n;i++)edge[i].clear();
for(int i=1;i<=n;i++)G[i].clear();
}
int main(){
int cas;
scanf("%d",&cas);
while(cas--){
solve();
Clear();
}
return 0;
}