题意:给一棵 n n n 个点的有点权的树,你需要找 k k k 条根到叶子的路径,使得路径并集的权值和最大。
n ≤ 2 × 1 0 5 n\leq 2\times 10^5 n≤2×105
其实就是个贪心,只是从这个角度更自然一点(
先有个显然的 dp,设 f ( u , k ) f(u,k) f(u,k) 为从 u u u 往下找 k k k 条链覆盖的权值最大值。
f ′ ( u , k ) = max i = 0 k { f ( u , i ) + f ( v , k − i ) } + [ k > 0 ] a u f'(u,k)=\max_{i=0}^k \{f(u,i)+f(v,k-i)\}+[k>0]a_u f′(u,k)=i=0maxk{f(u,i)+f(v,k−i)}+[k>0]au
发现是个闵可夫斯基和的形式。
然后往这个方向考虑,不管怎么理解都能看出这是个凸壳。
所以合并的时候继承叶子最多的儿子(或者直接继承子树最大的也可以)的凸包,其他的暴力做闵可夫斯基和。可以用堆来维护斜率也就是差分,然后暴力插入。最后把第一项 + a u +a_u +au 即可。
复杂度是 O ( n log 2 n ) \Omicron(n\log ^2n) O(nlog2n),但常数很小。
脑抽写了个平衡树,我是 SB
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <vector>
#include <cstdlib>
#define MAXN 200005
using namespace std;
inline int read()
{
int ans=0;
char c=getchar();
while (!isdigit(c)) c=getchar();
while (isdigit(c)) ans=(ans<<3)+(ans<<1)+(c^48),c=getchar();
return ans;
}
typedef long long ll;
int ch[MAXN][2],key[MAXN],tot;
ll val[MAXN];
inline int newnode(int v){return val[++tot]=v,key[tot]=rand(),tot;}
int merge(int x,int y)
{
if (!x||!y) return x|y;
if (key[x]<key[y]) return ch[x][1]=merge(ch[x][1],y),x;
return ch[y][0]=merge(x,ch[y][0]),y;
}
void split(int x,ll v,int& l,int& r)
{
if (!x) return (void)(l=r=0);
if (val[x]>v) split(ch[x][1],v,l,r),ch[x][1]=l,l=x;
else split(ch[x][0],v,l,r),ch[x][0]=r,r=x;
}
int getfir(int& x)
{
if (!ch[x][0])
{
int t=x;
x=ch[x][1],ch[t][1]=0;
return t;
}
return getfir(ch[x][0]);
}
inline void insert(int& x,int v)
{
int l,r;
split(x,val[v],l,r);
x=merge(merge(l,v),r);
}
vector<int> e[MAXN];
int a[MAXN],rt[MAXN],lev[MAXN],son[MAXN];
void dfs(int u)
{
if (e[u].empty()) return (void)(lev[u]=1);
for (int i=0;i<(int)e[u].size();i++)
{
dfs(e[u][i]);
if (lev[e[u][i]]>lev[son[u]]) son[u]=e[u][i];
lev[u]+=lev[e[u][i]];
}
}
void Dfs(int u)
{
if (son[u]) Dfs(son[u]),rt[u]=rt[son[u]];
else return (void)(rt[u]=newnode(a[u]));
for (int i=0;i<(int)e[u].size();i++)
if (e[u][i]!=son[u])
{
Dfs(e[u][i]);
while (rt[e[u][i]]) insert(rt[u],getfir(rt[e[u][i]]));
}
int p=getfir(rt[u]);
val[p]+=a[u];
rt[u]=merge(p,rt[u]);
}
int main()
{
int n,k;
n=read(),k=read();
for (int i=1;i<=n;i++) a[i]=read();
for (int i=1;i<n;i++)
{
int u,v;
u=read(),v=read();
e[u].push_back(v);
}
dfs(1),Dfs(1);
ll ans=0;
while (k--) ans+=val[getfir(rt[1])];
cout<<ans;
return 0;
}