题目大意
有一棵n个点的树,每条边有边权。有m次询问,每次给定k个关键点,问能切断根(1号点)到所有关键点的最小代价是多少?
n<=250000,m<=500000,
∑k
<=500000
Solution
可以发现,每次询问时只有关键点和关键点之间的LCA是有用的,知道了这些点,就能计算出答案。而且可以证明,这些点的总点数小于2k,总的复杂度可以变成O(
∑k
)。
用一个栈,就能建出这棵树。代码如下:
int top=1;
stack[top]=a[1];
for (int i=2;i<=cnt;i++)//cnt为要建树的节点数
{
while (top&&deep[stack[top]]>deep[lca(stack[top],a[i])])
top--;
if (stack[top]) add(stack[top],a[i],0);
stack[++top]=a[i];
}//a为要建树的所有节点。
建树的复杂度是
klogn
的。
建完这棵树后,在树上跑一次DP就可以。这个DP应该很显然了。
代码
#include<cstdio>
#include<algorithm>
#include<cstring>
using namespace std;
typedef long long ll;
const ll INF=1LL<<60;
int head[1000010],num,ti=0,dfn[1000010],next[2000010],vet[2000010],vel[2000010],flag[1000010],stack[1000010];
ll dp[1000010],dis[1000010];
int fa[1000010][22],deep[1000010],a[1000010],x[1000010];
bool cmp(int x,int y)
{
return dfn[x]<dfn[y];
}
void add(int u,int v,int s)
{
next[++num]=head[u];
head[u]=num;
vet[num]=v;
vel[num]=s;
}
void dfs(int u)
{
dfn[u]=++ti;
for (int i=head[u];i;i=next[i])
{
int v=vet[i];
if (v!=fa[u][0])
{
fa[v][0]=u,deep[v]=deep[u]+1,dis[v]=min((ll)vel[i],dis[u]);
dfs(v);
}
}
}
int lca(int u,int v)
{
if (deep[u]<deep[v]) swap(u,v);
for (int i=20;i>=0;i--)
if (fa[u][i]&&deep[fa[u][i]]>=deep[v]) u=fa[u][i];
if (u==v) return u;
for (int i=20;i>=0;i--)
if (fa[u][i]!=fa[v][i]) u=fa[u][i],v=fa[v][i];
return fa[u][0];
}
void DP(int u,int fa)
{
if (flag[u])
{
dp[u]=dis[u];
return;
}
dp[u]=dis[u];
ll sum=0;
int Flag=0;
for (int i=head[u];i;i=next[i])
{
int v=vet[i];
if (v!=fa)
{
DP(v,u);
sum+=dp[v];
Flag=1;
}
}
if (Flag) dp[u]=min(dp[u],sum);
}
int main()
{
int n;
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int u,v,c;
scanf("%d%d%d",&u,&v,&c);
add(u,v,c);
add(v,u,c);
}
deep[1]=1;
dis[1]=INF;
dfs(1);
for (int i=1;i<=20;i++)
for (int j=1;j<=n;j++)
fa[j][i]=fa[fa[j][i-1]][i-1];
int m;
memset(head,0,sizeof(head));
scanf("%d",&m);
while (m--)
{
int K;
scanf("%d",&K);
num=0;
for (int i=1;i<=K;i++)
{
scanf("%d",&a[i]);
x[i]=a[i];
flag[x[i]]=1;
}
int cnt=K;
a[++cnt]=1;
sort(a+1,a+1+cnt,cmp);
cnt=unique(a+1,a+1+cnt)-a-1;
int tmp=cnt;
for (int i=1;i<tmp;i++)
a[++cnt]=lca(a[i],a[i+1]);
sort(a+1,a+1+cnt,cmp);
cnt=unique(a+1,a+1+cnt)-a-1;
int top=1;
stack[top]=a[1];
for (int i=2;i<=cnt;i++)
{
while (top&&deep[stack[top]]>deep[lca(stack[top],a[i])])
top--;
if (stack[top]) add(stack[top],a[i],0);
stack[++top]=a[i];
}
DP(1,0);
printf("%lld\n",dp[1]);
for (int i=1;i<=K;i++) flag[x[i]]=0;
for (int i=1;i<=cnt;i++) head[a[i]]=dp[a[i]]=0;
}
return 0;
}