题意:有一颗树,树上有黑点和白点,问两点间满足黑点不超过K个的简单路径的最大边权之和是多少。
题解:考虑点分治,以到分治中心的黑点数量为下标建树状数组维护每个点到分治中心的边权之和,合并答案时查询黑点不超过 k的最大值。这样做的复杂度为 n ∗ l o g 2 n n*log^2n n∗log2n,不足以通过 2 e 5 2e5 2e5的数据。
不用数据结构,考虑维护 g[i][k][j]:当前分治中心为 i,前 k 棵子树中,到分治中心黑点个数不超过j个的最大边权和,i,k都可省略,简单表示为 g(v,j):前 v棵子树黑点不超过 j 的最大边权和,扫描每一棵子树时维护一个tmp[j] 表示当前子树中到分治中心黑点个数为 j个的最大边权和。显然当扫描到第v棵子树时: ans = max(ans,tmp[j] + g(v - 1,k - j))。考虑如何维护 g(v,j):当 v > 1时,tmp[j] = max(tmp[j - 1,tmp[j]];g(v,j) = max(g(v - 1,j,tmp[j])。
这样维护复杂度最差会达到 O ( n 2 ) O(n ^ 2) O(n2),(设刚好按每颗子树的 j 的递减的顺序合并,每一次都要遍历max(j)次)
注意子树的合并顺序并不会影响答案,所以可以提前扫描一遍子树,按 j 从小到大的顺序来合并子树(这就是启发式合并),这样做的复杂度是 O ( n ) O(n) O(n),排序的复杂度是均摊的,不会超过 n l o g n nlogn nlogn,总体复杂度为 O ( n l o g n ) O(nlogn) O(nlogn)
#include<bits/stdc++.h>
using namespace std;
#define pii pair<int,int>
#define fir first
#define sec second
const int maxm = 2e6 + 10;
const int maxn = 3e5 + 10;
int head[maxm],to[maxm],nxt[maxm],w[maxm],cnt;
int root,sz[maxn],f[maxn],tot,black[maxn],g[maxn],tmp[maxn];
int res;
bool done[maxn];
int n,k,m;
struct node{
int v,d,val;
node(int vi = 0,int di = 0,int vl = 0) {
v = vi;d = di;val = vl;
}
bool operator < (const node & rhs) const {
return val < rhs.val;
}
};
vector<node> h;
void init() {
res = cnt = 0;
fill(head,head + n + 1,-1);
fill(done,done + n + 1,0);
fill(g,g + n + 1,0);
fill(tmp,tmp + n + 1,0);
fill(black,black + n + 1,0);
}
void add(int u,int v,int wi) {
to[cnt] = v;
nxt[cnt] = head[u];
w[cnt] = wi;
head[u] = cnt++;
}
void getroot(int u,int fa) {
sz[u] = 1;f[u] = 0;
for(int i = head[u]; i + 1; i = nxt[i]) {
if(to[i] == fa || done[to[i]]) continue;
getroot(to[i],u);
sz[u] += sz[to[i]];
f[u] = max(f[u],sz[to[i]]);
}
f[u] = max(f[u],tot - sz[u]);
if(!root || f[u] < f[root]) root = u;
}
void dfs_dep(int u,int fa,int val,int &dep) {
if(dep > k) return;
dep = max(dep,val);
for(int i = head[u]; i + 1; i = nxt[i]) {
if(to[i] == fa || done[to[i]]) continue;
dfs_dep(to[i],u,val + black[to[i]],dep);
}
}
void dfs_val(int u,int fa,int dep,int val) {
if(dep > k) return;
tmp[dep] = max(tmp[dep],val);
for(int i = head[u]; i + 1; i = nxt[i]) {
if(to[i] == fa || done[to[i]]) continue;
dfs_val(to[i],u,dep + black[to[i]],val + w[i]);
}
}
int solve(int u) {
int ans = 0;
done[u] = true;h.clear();
int td = black[u];
for(int i = head[u]; i + 1; i = nxt[i]) {
if(done[to[i]]) continue;
int mxdep = 0;
dfs_dep(to[i],u,black[u] + black[to[i]],mxdep);
h.push_back(node(to[i],w[i],mxdep));
}
sort(h.begin(),h.end());
for(int i = 0; i < h.size(); i++) {
dfs_val(h[i].v,u,black[u] + black[h[i].v],h[i].d);
if(i) {
int cur = 0;
for(int j = h[i].val; j >= 0; j--) {
cur = min(k - j + black[u],h[i - 1].val);
ans = max(ans,g[cur] + tmp[j]);
}
}
for(int j = h[i].val; j >= 0; j--)
ans = max(ans,tmp[j]);
for(int j = 0; j <= h[i].val; j++) {
g[j] = max(g[j],tmp[j]);
tmp[j] = 0;
}
for(int j = 1; j <= h[i].val; j++)
g[j] = max(g[j],g[j - 1]);
}
if(h.size())
for(int i = 0; i <= h[h.size() - 1].val; i++) g[i] = 0;
return ans;
}
void divide(int rt) {
res = max(res,solve(rt));
for(int i = head[rt]; i + 1; i = nxt[i]) {
if(done[to[i]]) continue;
root = 0;tot = sz[to[i]];
getroot(to[i],rt);
divide(root);
}
}
int main() {
while(~scanf("%d%d%d",&n,&k,&m)) {
init();
for(int i = 1; i <= m; i++) {
int x;scanf("%d",&x);
black[x] = 1;
}
for(int i = 1; i < n; i++) {
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
add(u,v,w);add(v,u,w);
}
tot = n;root = 0;
getroot(1,-1);
divide(root);
printf("%d\n",res);
}
return 0;
}