Description
给出一棵n个点的树,每个点有点权,要求在树上选出恰好m条点不相交的链,每条链至少有k个点,要求点权和最大。
n ≤ 150005 n\leq 150005 n≤150005
Solution
看到恰好选出m条链,立刻反应到是凸优化,二分选择一条链需要花费的额外代价
但是要求每条链至少有K个点。
设
f
[
i
]
f[i]
f[i]为i子树中随便选若干条链的答案(满足k的限制)
一个点的贡献大概是
v
a
l
u
e
[
i
]
−
f
[
i
]
+
∑
p
∈
s
o
n
[
i
]
f
[
p
]
value[i]-f[i]+\sum\limits_{p\in son[i]}f[p]
value[i]−f[i]+p∈son[i]∑f[p],然后链上求和,再求一个最大值
我们考虑在链的LCA处做,容易想到启发式合并,枚举小的儿子,然后大的儿子采用数据结构的一类的维护。
这样的复杂度是 O ( n log 2 n ) O(n\log ^2 n) O(nlog2n)的,加上凸优化的复杂度就是 O ( n log 3 n ) O(n\log ^3n) O(nlog3n),明显不能通过
我们进一步思考,启发式合并的时候暴力枚举轻儿子的每个节点非常浪费,我们只需要知道每个深度的答案。
基于这个思想,对于每个点维护子树中每个深度的答案。
但是还是采用启发式合并(即轻重链剖分)直接合并的话复杂度好像还是不对
然而我们有长链剖分!
长链剖分就是用来解决枚举深度的问题的,每次枚举所有除最长链以外的其他儿子的每个深度,可以证明这样的时间复杂度是
O
(
n
)
O(n)
O(n)的
考虑我们需要维护后缀和+前缀max,可以用一个vector存储,枚举深度的时候顺便修改。
实现细节很多。。。。
Code
代码极丑
#include <cstdio>
#include <iostream>
#include <cstdlib>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <vector>
#define fo(i,a,b) for(int i=a;i<=b;++i)
#define fod(i,a,b) for(int i=a;i>=b;--i)
#define N 150005
#define LL long long
using namespace std;
int n,m,li,fs[N],nt[2*N],dt[2*N],m1,dst[N],son[N],n1,dep[N],rt[N];
LL pr[N],ans,mid,f[N],g[N],tg[N][2],s1;
vector<LL> pt[N][2];
void dfs(int k,int fa)
{
dep[k]=dep[fa]+1;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa)
{
dfs(p,k);
if(dst[p]+1>dst[k]) son[k]=p,dst[k]=dst[p]+1;
}
}
}
void upd(int k,int x,int p,int y)
{
if(pt[k][0][x]+tg[k][0]<pt[p][0][y]+tg[p][0]||(pt[k][0][x]+tg[k][0]==pt[p][0][y]+tg[p][0]&&pt[k][1][x]+tg[k][1]<pt[p][1][y]+tg[p][1]))
{
pt[k][0][x]=pt[p][0][y]+tg[p][0]-tg[k][0];
pt[k][1][x]=pt[p][1][y]+tg[p][1]-tg[k][1];
}
}
void dp(int k,int fa)
{
LL sv=pr[k],sl=0;
if(son[k]) dp(son[k],k),rt[k]=rt[son[k]],sv+=f[son[k]],sl+=g[son[k]];
else rt[k]=++n1,tg[n1][0]=tg[n1][1]=0,pt[n1][0].clear(),pt[n1][1].clear();
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa&&p!=son[k]) dp(p,k),sv+=f[p],sl+=g[p];
}
f[k]=0,g[k]=0;
if(sv-pr[k]>f[k]||(sv-pr[k]==f[k]&&sl>g[k])) f[k]=sv-pr[k],g[k]=sl;
for(int i=fs[k];i;i=nt[i])
{
int p=dt[i];
if(p!=fa&&p!=son[k])
{
fo(j,0,dst[p])
{
if(dst[k]+2+j<li) continue;
LL v=pt[rt[k]][0][dst[k]-1-max(0,li-j-3)]+tg[rt[k]][0]+sv+pt[rt[p]][0][dst[p]-j]-mid+tg[rt[p]][0];
LL v1=pt[rt[k]][1][dst[k]-1-max(0,li-j-3)]+tg[rt[k]][1]+sl+pt[rt[p]][1][dst[p]-j]+tg[rt[p]][1]+1;
if(v>f[k]||(f[k]==v&&v1>g[k]))
{
f[k]=v;
g[k]=v1;
}
}
fo(j,0,dst[p]) upd(rt[k],dst[k]-1-j,rt[p],dst[p]-j);
LL v=-1e17,vl=0;
if(dst[p]>=li-2) v=pt[rt[p]][0][dst[p]-max(0,li-2)]+tg[rt[p]][0]+sv-mid,vl=pt[rt[p]][1][dst[p]-max(0,li-2)]+tg[rt[p]][1]+sl+1;
if(li==1)
{
if(sv-mid>v||(sv-mid==v&&sl+1>vl)) v=sv-mid,vl=sl+1;
}
if(v>f[k]||(v==f[k]&&vl>g[k])) f[k]=v,g[k]=vl;
}
}
LL v=-1e17,vl=0;
if(dst[k]>=li-1&&dst[k]) v=pt[rt[k]][0][dst[k]-max(1,li-1)]+tg[rt[k]][0]+sv-mid,vl=pt[rt[k]][1][dst[k]-max(1,li-1)]+tg[rt[k]][1]+sl+1;
if(li==1)
{
if(sv-mid>v||(sv-mid==v&&sl+1>vl)) v=sv-mid,vl=sl+1;
}
if(v>f[k]||(v==f[k]&&vl>g[k])) f[k]=v,g[k]=vl;
if(f[k]>ans||(f[k]==ans&&g[k]>s1)) ans=f[k],s1=g[k];
sv-=f[k],sl-=g[k];
tg[rt[k]][0]+=sv,tg[rt[k]][1]+=sl;
if(dst[k]>0&&(pt[rt[k]][0][dst[k]-1]+tg[rt[k]][0]>sv||(pt[rt[k]][0][dst[k]-1]+tg[rt[k]][0]==sv&&pt[rt[k]][1][dst[k]-1]+tg[rt[k]][1]>sl))) pt[rt[k]][0].push_back(pt[rt[k]][0][dst[k]-1]),pt[rt[k]][1].push_back(pt[rt[k]][1][dst[k]-1]);
else pt[rt[k]][0].push_back(sv-tg[rt[k]][0]),pt[rt[k]][1].push_back(sl-tg[rt[k]][1]);
}
void link(int x,int y)
{
nt[++m1]=fs[x];
dt[fs[x]=m1]=y;
}
int main()
{
cin>>n>>m>>li;
LL l=-1e13,r=1e13;
fo(i,1,n)
{
scanf("%lld",&pr[i]);
}
fo(i,1,n-1)
{
int x,y;
scanf("%d%d",&x,&y);
link(x,y),link(y,x);
}
dfs(1,0);
while(l+1<r)
{
mid=(l+r)/2;
ans=s1=0;
n1=0;
dp(1,0);
if(s1==m)
{
printf("%lld\n",(LL)(ans+mid*(LL)m));
return 0;
}
if(s1<m) r=mid;
else l=mid;
}
mid=r;n1=0;ans=s1=0;dp(1,0);
if(s1==m)
{
printf("%lld\n",(LL)(ans+mid*(LL)m));
return 0;
}
else
{
mid=l;n1=0;ans=s1=0;dp(1,0);
printf("%lld\n",(LL)(ans+mid*(LL)m));
}
}