分析:
显然是一道dp,那我们就想方程吧
一开始dp的方程不是很成熟:
设计了一个状态
f[i][0/1]
f
[
i
]
[
0
/
1
]
,表示是否选择第
i
i
个结点和根结点的路径
如果选择了结点和根结点的路径,那么就没有必要选子树中的路径了
如果没选择,就需要选子树中的路径,但是这个转移方程不是很明确
实际上我们直接用最简单的dp就可以了:
f[i]
f
[
i
]
表示到
i
i
结点的最小花费
那么每个结点有两种可能:选还是不选
选的花费就是,不选的话就统计一下子结点的
f
f
值(递归下去)
朴素的dp每次询问都是
O(n)
O
(
n
)
的复杂度,显然会T啊
注意题目中有限制
Σk<=500000
Σ
k
<=
500000
,因此如果能将一次的时间复杂度减小到
O(k)
O
(
k
)
或者
O(klogk)
O
(
k
l
o
g
k
)
,就能通过了
因此,我们的关键是能构造出一棵结点
<=O(k)
<=
O
(
k
)
<script type="math/tex" id="MathJax-Element-13"><=O(k)</script>级别的虚树,并能在
O(k)
O
(
k
)
或者
O(klogk)
O
(
k
l
o
g
k
)
的时间构造完成
简单说一下吧:
定义某一次询问给出的岛屿为关键点
注意到对于某对关键点
(x,y)
(
x
,
y
)
假设
x−>lca(x,y)
x
−
>
l
c
a
(
x
,
y
)
的路径中,没有点是某一对关键点的
lca
l
c
a
,也没有其他关键点,我们只用考虑这两个结点
我们要在这条路径上选一条边并删除,显然我们会贪心的选择最小的边
这样我们只要得到
x−>lca(x,y)
x
−
>
l
c
a
(
x
,
y
)
中的最小第一条边即可,其他路径上的点对答案不会产生任何影响
换句话说将
lca(x,y)−>...−>x
l
c
a
(
x
,
y
)
−
>
.
.
.
−
>
x
的路径直接压缩成
lca(x,y)−>x
l
c
a
(
x
,
y
)
−
>
x
,对答案不会产生影响
因此我们只需要保留所有关键点,以及它们两两之间的
lca
l
c
a
按照原树的祖先关系连边,在构造得到的虚树上面跑dp即可
因为
k
k
个点两两之间不同的 只有
k−1
k
−
1
个,因此产生的虚树是
O(k)
O
(
k
)
的
tip
build虚树的时候单向连边即可
dp的时候处理完一个结点,清空st:st[now]=0;
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
#include<cmath>
#define ll long long
using namespace std;
const ll INF=1e18;
const int N=500010;
struct node{
int y,nxt;
ll v;
};
node way[N<<1];
int lg,deep[N],pre[N][20],in[N],clo,a[N],n,m,S[N],top,st[N],tot=0;
ll len[N][20],f[N];
void add(int u,int w,ll z)
{
tot++;
way[tot].y=w;way[tot].nxt=st[u];st[u]=tot;way[tot].v=z;
tot++;
way[tot].y=u;way[tot].nxt=st[w];st[w]=tot;way[tot].v=z;
}
void dfs(int now,int fa,int dep)
{
deep[now]=dep;
pre[now][0]=fa; in[now]=++clo;
for (int i=st[now];i;i=way[i].nxt)
if (way[i].y!=fa)
{
len[way[i].y][0]=way[i].v;
dfs(way[i].y,now,dep+1);
}
}
int lca(int x,int y)
{
if (deep[x]<deep[y]) swap(x,y);
int d=deep[x]-deep[y];
if (d)
for (int i=0;i<=lg&&d;i++,d>>=1)
if (d&1)
x=pre[x][i];
if (x==y) return x;
for (int i=lg;i>=0;i--)
if (pre[x][i]!=pre[y][i])
{
x=pre[x][i];
y=pre[y][i];
}
return pre[x][0];
}
void prepare()
{
clo=0;
memset(len,127,sizeof(len));
dfs(1,0,1);
for (int i=1;i<=lg;i++)
for (int j=1;j<=n;j++)
pre[j][i]=pre[pre[j][i-1]][i-1],
len[j][i]=min(len[j][i-1],len[pre[j][i-1]][i-1]);
}
ll getlen(int x,int y)
{
ll sum=INF;
if (deep[x]<deep[y]) swap(x,y);
int d=deep[x]-deep[y];
if (d)
for (int i=0;i<=lg&&d;i++,d>>=1)
if (d&1)
sum=min(sum,len[x][i]),x=pre[x][i];
if (x==y) return sum;
for (int i=lg;i>=0;i--)
if (pre[x][i]!=pre[y][i])
{
sum=min(sum,len[x][i]);
sum=min(sum,len[y][i]);
x=pre[x][i];
y=pre[y][i];
}
sum=min(sum,len[x][0]); sum=min(sum,len[y][0]);
return sum;
}
int cmp(int a,int b)
{
return in[a]<in[b];
}
void build(int x,int y)
{
if (x==y) return;
tot++;
ll t=getlen(x,y);
way[tot].y=y;way[tot].nxt=st[x];st[x]=tot;way[tot].v=t;
}
void dp(int now,ll mn)
{
f[now]=mn;
ll sum=0;
for (int i=st[now];i;i=way[i].nxt)
{
dp(way[i].y,way[i].v);
sum+=f[way[i].y];
}
st[now]=0; //清空st
if (sum) f[now]=min(f[now],sum); //if (sum)
}
void solve()
{
int k;
scanf("%d",&k);
for (int i=1;i<=k;i++) scanf("%d",&a[i]);
sort(a+1,a+1+k,cmp);
int cnt=0;
a[++cnt]=a[1];
for (int i=2;i<=k;i++)
if (lca(a[cnt],a[i])!=a[cnt]) a[++cnt]=a[i];
k=cnt;
tot=0; top=0; S[++top]=1;
for (int i=1;i<=k;i++)
{
int now=a[i];
int p=lca(now,S[top]);
while (1)
{
if (deep[p]>=deep[S[top-1]])
{
build(p,S[top--]);
if (S[top]!=p) S[++top]=p;
break;
}
build(S[top-1],S[top]);
top--;
}
if (now!=S[top]) S[++top]=now;
}
while (top-1) build(S[top-1],S[top]),top--;
dp(1,INF);
printf("%lld\n",f[1]);
}
int main()
{
scanf("%d",&n); lg=log(n)/log(2);
for (int i=1;i<n;i++)
{
int u,w,z;
scanf("%d%d%d",&u,&w,&z);
add(u,w,z);
}
prepare();
scanf("%d",&m);
memset(st,0,sizeof(st));
for (int i=1;i<=m;i++)
solve();
return 0;
}