题目大意: 给出一棵边带权的树,多组询问,每次给出若干个关键节点,要求删掉若干条边,使点 1 1 1 不能到达任意一个关键节点,问删掉的边的最小权值和是多少。
题解
先考虑只有 1 1 1 组询问的情况,可以考虑dp,以点 1 1 1 为根,设 f [ x ] f[x] f[x] 表示让 x x x 子树内的所有关键节点不能到达点 1 1 1 的最小代价。
那么转移有两种,一是删掉 x x x 到 1 1 1 的某条权值最小的边,或者是 ∑ y ∈ s o n f [ y ] \sum_{y\in son} f[y] ∑y∈sonf[y]。
但是多组询问的话时间复杂度就变成了 O ( n m ) O(nm) O(nm),考虑优化这个dp过程。
发现dp过程中,其实只有在关键点
和关键点的lca
处的决策有意义,所以可以对于每次询问,建出虚树然后再跑dp。
然后建虚树时还可以优化一下,发现假如关键点 x x x 在关键点 y y y 的子树内,那么就不用管 x x x 了,因为将 y y y 断掉之后 x x x 肯定也被断掉了。
代码如下:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
#define maxn 1000010
#define inf 999999999999999999ll
#define ll long long
int n,m;
struct edge{int x,y,z,next;}e[maxn<<1];
int first[maxn],len=0;
void buildroad(int x,int y,int z=0){e[++len]=(edge){x,y,z,first[x]};first[x]=len;}
int f[maxn][20],id[maxn],tot=0,deep[maxn];ll val[maxn];
void dfs(int x,int fa,ll Val,int dep){
id[x]=++tot;f[x][0]=fa;val[x]=Val;deep[x]=dep;
for(int i=first[x];i;i=e[i].next)if(e[i].y!=fa)dfs(e[i].y,x,min(Val,(ll)e[i].z),dep+1);
}
bool cmp(int x,int y){return id[x]<id[y];}
int a[maxn],zhan[maxn],t=0;
int get_lca(int x,int y)
{
if(deep[x]>deep[y])swap(x,y);
for(int i=19;i>=0;i--)if(deep[f[y][i]]>=deep[x])y=f[y][i];
if(x!=y){for(int i=19;i>=0;i--)if(f[x][i]!=f[y][i])x=f[x][i],y=f[y][i];x=f[x][0];}
return x;
}
void add(int x)//建虚树
{
if(t==1)return (void)(zhan[++t]=x);
int lca=get_lca(zhan[t],x);if(lca==zhan[t])return;
while(deep[zhan[t-1]]>=deep[lca])buildroad(zhan[t-1],zhan[t]),t--;
if(zhan[t]!=lca)buildroad(lca,zhan[t]),zhan[t]=lca; zhan[++t]=x;
}
ll g[maxn];void dp(int x)
{
ll tot=0;
for(int i=first[x];i;i=e[i].next)
dp(e[i].y),tot+=g[e[i].y];
if(tot)g[x]=min(val[x],tot);else g[x]=val[x];
}
int main()
{
scanf("%d",&n);for(int i=1,x,y,z;i<n;i++)
scanf("%d %d %d",&x,&y,&z),buildroad(x,y,z),buildroad(y,x,z);dfs(1,0,inf,1);
for(int j=1;j<=19;j++)for(int i=1;i<=n;i++)f[i][j]=f[f[i][j-1]][j-1];
scanf("%d",&m);for(int i=1,k;i<=m;i++)
{
while(len)first[e[len--].x]=0;
scanf("%d",&k);for(int j=1;j<=k;j++)scanf("%d",&a[j]);
sort(a+1,a+k+1,cmp);zhan[t=1]=1;for(int j=1;j<=k;j++)add(a[j]);
while(t>1)buildroad(zhan[t-1],zhan[t],0),t--;dp(1);printf("%lld\n",g[1]);
}
}