题目描述:
大意:给出树上的一些带权值的链,选出一些点不相交的链,使得它们的权值和最大。 n , m ≤ 2 ∗ 1 0 5 n,m\le2*10^5 n,m≤2∗105
题目分析:
问题就是怎么求删掉这条链后的极大子树的DP值。
大力数据结构可以重链剖分后在每个点上存下轻儿子的DP值之和,然后往上跳重链,每次加上这条重链上存的值(单点修改区间求和),减去上一次跳上来的轻儿子,加上当前点的重儿子。复杂度 O ( n l o g 2 n ) O(nlog^2n) O(nlog2n)
对于这道题还有复杂度更低的做法:记
s
[
x
]
s[x]
s[x]表示
x
x
x的所有儿子的DP值之和,那么对于一条以
u
u
u为LCA,链端点为
x
,
y
x,y
x,y的链,它对
f
[
u
]
f[u]
f[u]的贡献就是这条链上的
s
s
s值之和减去链上除了
u
u
u点的
f
f
f值。
记
g
[
x
]
=
s
[
x
]
−
f
[
x
]
g[x]=s[x]-f[x]
g[x]=s[x]−f[x],贡献变为链上除了
u
u
u点的
g
g
g值之和加上
u
u
u的所有儿子的
f
f
f值之和。
由于DP是从下往上的,
g
[
u
]
g[u]
g[u]一定会在链端点在它的子树中时被统计,所以给
u
u
u的子树直接加上
g
[
u
]
g[u]
g[u],查询时直接查询链端点处的值即可(区间修改单点查询),复杂度
O
(
n
l
o
g
n
)
O(nlogn)
O(nlogn)。
上面这种做法,在最后一步可以将子树加换成并查集缩链并维护权值,复杂度更低,但是在换父亲时有一点实现上的细节,详见这篇博客
Code:
#include<bits/stdc++.h>
#define maxn 200005
using namespace std;
char cb[1<<20],*cs,*ct;
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<20,stdin),cs==ct)?0:*cs++)
template<class T>inline void read(T &a){
char c;while(!isdigit(c=getc()));
for(a=c-'0';isdigit(c=getc());a=a*10+c-'0');
}
const int Log = 17;
int n,m,F[Log+1][maxn],dep[maxn],in[maxn],out[maxn],tim,f[maxn];
int arr[maxn],fir[maxn],nxt[maxn];
inline void line(int x,int y){nxt[y]=fir[x],fir[x]=y;}
struct node{int x,y,w;};
vector<node>q[maxn];
inline int LCA(int u,int v){
if(dep[u]<dep[v]) swap(u,v);
for(int i=0,d=dep[u]-dep[v];d;d>>=1,i++) if(d&1) u=F[i][u];
if(u==v) return u;
for(int i=Log;i>=0;i--) if(F[i][u]!=F[i][v]) u=F[i][u],v=F[i][v];
return F[0][u];
}
void dfs(int u){in[u]=++tim;for(int i=fir[u];i;i=nxt[i]) dep[i]=dep[u]+1,dfs(i);out[u]=tim;}
inline void upd(int i,int v){for(;i<=n;i+=i&-i) arr[i]+=v;}
inline int qsum(int i){int s=0;for(;i;i-=i&-i) s+=arr[i];return s;}
int main()
{
freopen("1.in","r",stdin);
int x,y,w;
read(n),read(m);
for(int i=2;i<=n;i++) read(F[0][i]),line(F[0][i],i);
for(int j=1;j<=Log;j++)
for(int i=1;i<=n;i++)
F[j][i]=F[j-1][F[j-1][i]];
dfs(1);
while(m--){
read(x),read(y),read(w);
q[LCA(x,y)].push_back((node){x,y,w});
}
for(int u=n;u>=1;u--){
for(node t:q[u]) f[u]=max(f[u],qsum(in[t.x])+qsum(in[t.y])+t.w);
int s=0;
for(int v=fir[u];v;v=nxt[v]) s+=f[v];
f[u]+=s;
upd(in[u],s-f[u]),upd(out[u]+1,f[u]-s);
}
printf("%d\n",f[1]);
}