题目传送门:http://www.lydsy.com/JudgeOnline/problem.php?id=2286
题目分析：先把虚树建出来，虚树上每条边的权值等于原树上对应链的边权最小值。然后在虚树上做树形 DP：如果某个点是关键点，就一定要割掉它的父边；否则既可以割掉它的父边，也可以让它的儿子各自处理。时间复杂度为 $O((n+\sum k)\log n)$。
我发现我写的虚树还是太少了,写代码的时候有很多SB的错误,比如tail写成last等等……QAQ
CODE:
#include<iostream>
#include<string>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<cstdlib>
#include<stdio.h>
#include<algorithm>
#include<utility>
#include<vector>
using namespace std;
// Capacity of every per-vertex array in this program.
const int maxn = 300100;
// Number of binary-lifting levels stored per vertex (levels 0 .. maxl-1).
const int maxl = 21;
// 64-bit accumulator type for cut costs.
typedef long long LL;
// "Infinite" cost, used as the parent-edge cost of the virtual-tree root (equals 1e17).
const long long oo = 100000000000000000LL;
struct edge
{
int obj,len;
edge *Next;
} e[maxn<<3];
edge *head[maxn];
// Index of the most recently used slot in the edge pool e[] (-1: nothing used yet).
int cur = -1;

// ---- Binary-lifting data over the original tree ----
int fa[maxn][maxl];   // fa[v][j]: the 2^j-th ancestor of v
int Min[maxn][maxl];  // Min[v][j]: smallest edge weight among the 2^j edges above v
int dep[maxn];        // depth of v
int dfn[maxn];        // preorder timestamp of v
int Time = 0;         // last timestamp handed out

// ---- Per-query virtual-tree state ----
int Fa[maxn];         // parent of v in the current virtual tree (0 means none)
bool vis[maxn];       // whether v is one of the query's key vertices
LL val[maxn];         // cost of cutting v away from its virtual parent
LL f[maxn];           // DP result for the virtual subtree rooted at v
int Node[maxn];       // key vertices first, then the extra vertices added while building
int sak[maxn];        // stack used while building the virtual tree
int cnt;              // number of entries currently stored in Node
int tail;             // index of the top element of sak
int n;                // number of tree vertices
int q;                // number of queries
int m;                // number of key vertices in the current query
void Add(int x,int y,int z)
{
cur++;
e[cur].obj=y;
e[cur].len=z;
e[cur].Next=head[x];
head[x]=e+cur;
}
void Dfs(int node)
{
dfn[node]=++Time;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
if (son==fa[node][0]) continue;
fa[son][0]=node;
Min[son][0]=p->len;
dep[son]=dep[node]+1;
Dfs(son);
}
}
// Sort predicate: vertex x comes before vertex y when x was reached earlier
// in the preorder DFS (smaller dfn timestamp).
bool Comp(int x,int y)
{
    const int rankX=dfn[x];
    const int rankY=dfn[y];
    return rankY>rankX;
}
int Lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
for (int j=maxl-1; j>=0; j--)
if (dep[ fa[x][j] ]>=dep[y]) x=fa[x][j];
if (x==y) return x;
for (int j=maxl-1; j>=0; j--)
if (fa[x][j]!=fa[y][j]) x=fa[x][j],y=fa[y][j];
return fa[x][0];
}
// Returns the smallest edge weight met while climbing from x to the ancestor
// at depth dep[y]. Callers pass y = virtual parent of x, so y is expected to
// be an ancestor of x. Returns 1e9 if no edge is climbed.
int Jump(int x,int y)
{
    int best=1e9;
    int at=x;
    const int target=dep[y];
    for (int j=maxl-1; j>=0; --j)
    {
        const int up=fa[at][j];
        if (dep[up]<target) continue;  // this jump would overshoot y
        best=min(best,Min[at][j]);
        at=up;
    }
    return best;
}
void DP(int node)
{
if (vis[node]) f[node]=val[node];
else
{
LL sum=0;
for (edge *p=head[node]; p; p=p->Next)
{
int son=p->obj;
DP(son);
sum+=f[son];
}
f[node]=min(val[node],sum);
}
}
int main()
{
//freopen("2286.in","r",stdin);
//freopen("2286.out","w",stdout);
scanf("%d",&n);
for (int i=1; i<=n; i++) head[i]=NULL;
for (int i=1; i<n; i++)
{
int u,v,w;
scanf("%d%d%d",&u,&v,&w);
Add(u,v,w);
Add(v,u,w);
}
fa[1][0]=1;
Min[1][0]=(int)1e9;
dep[1]=1;
Dfs(1);
for (int j=1; j<maxl; j++)
for (int i=1; i<=n; i++)
{
int mid=fa[i][j-1];
fa[i][j]=fa[mid][j-1];
Min[i][j]=min(Min[i][j-1],Min[mid][j-1]);
}
scanf("%d",&q);
while (q--)
{
scanf("%d",&m);
for (int i=1; i<=m; i++) scanf("%d",&Node[i]);
sort(Node+1,Node+m+1,Comp);
cnt=m;
tail=1;
sak[1]=Node[1];
for (int i=2; i<=m; i++)
{
int p=Node[i];
int x=Lca(sak[tail],p);
int last=0;
while (dep[ sak[tail] ]>dep[x]) last=sak[tail--];
if (sak[tail]!=x) Fa[x]=sak[tail],sak[++tail]=Node[++cnt]=x;
if (last) Fa[last]=x;
sak[++tail]=p;
Fa[p]=x;
}
int root=sak[1];
if (root!=1)
{
Fa[root]=Node[++cnt]=1;
root=1;
}
Fa[root]=0;
for (int i=1; i<=cnt; i++) head[ Node[i] ]=NULL;
for (int i=1; i<=cnt; i++)
{
int p=Node[i];
if (Fa[p]) Add(Fa[p],p,0),val[p]=Jump(p,Fa[p]);
else val[p]=oo;
if (i<=m) vis[p]=true; else vis[p]=false;
}
DP(root);
printf("%lld\n",f[root]);
}
return 0;
}