题意:
给你一棵 n n n个点的树,你需要在树上选择恰好 m m m条点不相交的、长度至少为 k k k的路径,使得路径所覆盖的点权和尽可能大。求最大点权和。数据保证有解。
数据范围:
100 100 100%的数据: n , m , k ≤ 1.5 ∗ 1 0 5 n,m,k \leq 1.5*10^5 n,m,k≤1.5∗105
Analysis:
我们考虑一个暴力的
D
P
DP
DP:
f
i
,
j
,
k
f_{i,j,k}
fi,j,k,表示当前
D
P
DP
DP到第
i
i
i个节点,选了
j
j
j条链,并且伸出去一条长度为
k
k
k的链的最大权值和。
发现
j
j
j每次增大的值逐渐变小,因为我们每一次肯定选一条权值最大的链加进去,它的函数是凸的,我们就可以凸优化,二分一条直线去切这个凸壳,被切到的地方之后都不优了。
那么我们就可以去掉一维状态,为了方便转移,我们将状态重新定义为:
f
i
,
j
f_{i,j}
fi,j表示
D
P
DP
DP到
i
i
i节点,伸出去一条长度至少为
j
j
j的链,并且选择了若干条链的最大权值和。
然后发现第二维被深度限制,用长链剖分优化转移即可。
复杂度
O
(
n
log
值
域
)
O(n\log 值域)
O(nlog值域)。这题想和写是两个世界。
Code:
# include<cstdio>
# include<cstring>
# include<algorithm>
using namespace std;
const int N = 2e5 + 5;
typedef long long ll;
const ll inf = 1e7;
struct node
{
ll v; int k;
bool operator < (node r) const
{ return v == r.v ? k > r.k : v < r.v; }
node operator + (node r) const
{ return (node){v + r.v,k + r.k}; }
node operator - (node r) const
{ return (node){v - r.v,k - r.k}; }
};
int st[N],to[N << 1],nx[N << 1],w[N];
int h[N],son[N];
node c[N << 3],g[N],z[N];
node *f[N],*now;
int n,m,k,tot;
inline void add(int u,int v)
{
to[++tot] = v,nx[tot] = st[u],st[u] = tot;
to[++tot] = u,nx[tot] = st[v],st[v] = tot;
}
inline void newc(int x) { f[x] = now,now = now + 2 * h[x] + 1; }
inline void pre(int x,int F)
{
for (int i = st[x] ; i ; i = nx[i])
if (to[i] != F)
{ pre(to[i],x); if (h[to[i]] > h[son[x]]) son[x] = to[i]; }
h[x] = h[son[x]] + 1;
}
inline void dfs(int x,int F,ll s)
{
g[x] = z[x] = (node){0,0}; node is = (node){0,0};
if (son[x])
{
f[son[x]] = f[x] + 1; dfs(son[x],x,s);
z[x] = z[son[x]],z[x].v += w[x]; if (g[son[x]].v >= 0) g[x] = g[son[x]],is = g[son[x]];
}
if (!son[x]) f[x][0] = (node){w[x] - s,1}; else f[x][0] = max(((node){w[x] - s,1} - z[x]) + is,f[x][1]);
for (int i = st[x] ; i ; i = nx[i])
if (to[i] != F && to[i] != son[x])
{
newc(to[i]),dfs(to[i],x,s);
if (g[to[i]].v >= 0) g[x] = g[x] + g[to[i]],z[x] = z[x] + g[to[i]];
}
node S = g[x];
for (int i = st[x] ; i ; i = nx[i])
if (to[i] != F && to[i] != son[x])
{
node now = (node){0,0}; if (g[to[i]].v >= 0) now = g[to[i]];
for (int j = 0 ; j < min(h[to[i]] + 1,k) ; ++j)
{
int nx = max(k - j - 2,0); if (nx > h[x]) continue;
f[to[i]][j] = f[to[i]][j] + z[to[i]],f[x][nx] = f[x][nx] + z[x];
node C = (f[to[i]][j] - now) + f[x][nx]; C.v += s,--C.k;
g[x] = max(g[x],C); f[to[i]][j] = f[to[i]][j] - z[to[i]],f[x][nx] = f[x][nx] - z[x];
}
for (int j = h[to[i]] ; ~j ; --j)
{
f[x][j + 1] = f[x][j + 1] + z[x],f[to[i]][j] = f[to[i]][j] + z[to[i]],f[to[i]][j].v += w[x];
f[x][j + 1] = max(f[x][j + 1],(S - now) + f[to[i]][j]);
f[to[i]][j] = f[to[i]][j] - z[to[i]],f[to[i]][j].v -= w[x],f[x][j + 1] = f[x][j + 1] - z[x];
if (j + 2 <= h[x]) f[x][j + 1] = max(f[x][j + 1],f[x][j + 2]);
}
f[x][0] = max(f[x][0],f[x][1]);
}
if (h[x] >= k) g[x] = max(g[x],f[x][k - 1] + z[x]);
}
inline void check(ll x)
{
for (int i = 0 ; i <= 3 * n ; ++i) c[i] = (node){-inf,0};
now = c; newc(1),dfs(1,0,x);
}
int main()
{
// freopen("a.in","r",stdin);
freopen("tree.in","r",stdin);
freopen("tree.out","w",stdout);
scanf("%d%d%d",&n,&m,&k);
for (int i = 1 ; i <= n ; ++i) scanf("%d",&w[i]);
for (int i = 1 ; i < n ; ++i)
{
int u,v; scanf("%d%d",&u,&v);
add(u,v);
} pre(1,0); ll l = -inf,r = inf,ans = r;
while (l <= r)
{
ll mid = (l + r) >> 1; check(mid);
if (g[1].k <= m) r = mid - 1,ans = mid;
else l = mid + 1;
}
check(ans);
printf("%lld\n",g[1].v + ans * m);
return 0;
}