题目:
题解:
虚树,一种可以快速优化树上dp的东西,将每次基于n(节点数)的询问转化为基于
∑m
∑
m
的,Emmm,虚树就是介个东西吧
除了询问节点外,任意两点的lca都会存在于虚树中,叙述中两点间的路径,比如说图二的2-18这条边要存储2-4-18这条链的信息
对于这道题目来说,我们先考虑基本的dp,dp[i]表示把以i为子树根节点的关键点全都切掉后的最小代价
那么如果i是关键点dp[i]就是ta到父亲的边;如果i不是关键点,dp[i]就是【所有子节点dp[v]的和】和【自己通往父节点这条边】的最小值
如果我们就按照这个思路的话可能只有40pts?可以发现这题询问很多,显然每次询问我们必须用跟k相关的时间解决,而不能跟n相关。
我们可以发现这些询问点之外的点是没有用的,虚树之间的边连树边最小值就好了,这样我们生成的树就只有关键信息了,而且你如果建树就用了
O(n2)
O
(
n
2
)
还不如不建,利用单调栈我们可以达到
O(logn)
O
(
l
o
g
n
)
的级别
而且你每次建一棵树都要memset的话就又带来O(n)的复杂度了,但是不清又不行,那就在递归里清数组吧
代码:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define LL long long
#define INF 1e9
using namespace std;
const int sz=19;
const int N=250005;
int n,m,tot,nxt[N*2],point[N],v[N*2],c[N*2],ask[N],h[N],mi[sz],f[N][sz],s[N][sz],in[N],out[N],nn,last[N],gj[N],flag[N],stack[N],top;
LL dp[N];
void addline(int x,int y,int w)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y; c[tot]=w;
}
void dfs(int x,int fa)
{
in[x]=++nn;h[x]=h[fa]+1;
for (int i=1;i<sz;i++)
if (h[x]<mi[i]) break;
else f[x][i]=f[f[x][i-1]][i-1],s[x][i]=min(s[x][i-1],s[f[x][i-1]][i-1]);
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa) s[v[i]][0]=c[i],f[v[i]][0]=x,dfs(v[i],x);
out[x]=nn;
}
int lca(int x,int y)
{
if (h[x]<h[y]) swap(x,y);
int k=h[x]-h[y];
for (int i=0;i<sz;i++)
if (k&(1<<i)) x=f[x][i];
if (x==y) return x;
for (int i=sz-1;i>=0;i--)
if (f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int cmp(int x,int y){return in[x]<in[y];}
int find(int x,int y)
{
int minn=INF,k=h[y]-h[x];
for (int i=0;i<sz;i++)
if (k&(1<<i)) minn=min(minn,s[y][i]),y=f[y][i];
return minn;
}
void treedp(int x)
{
dp[x]=0;
for (int i=point[x];i;i=nxt[i])
{
last[v[i]]=c[i];
treedp(v[i]);
dp[x]+=dp[v[i]];
}
if (gj[x]==m) dp[x]=(LL)last[x];
else if (x!=1) dp[x]=min(dp[x],(LL)last[x]);
point[x]=0;
}
void work()
{
int k;scanf("%d",&k);
for (int i=1;i<=k;i++)
{
scanf("%d",&ask[i]); gj[ask[i]]=flag[ask[i]]=m;
}
sort(ask+1,ask+k+1,cmp);//深度小的-大的
ask[0]=k;
for (int i=2;i<=k;i++)
{
int t=lca(ask[i],ask[i-1]);
if (flag[t]!=m) flag[t]=m,ask[++ask[0]]=t;
}
if (flag[1]!=m) flag[1]=m,ask[++ask[0]]=1;
sort(ask+1,ask+ask[0]+1,cmp);
tot=0; stack[top=1]=1;
for (int i=2;i<=ask[0];i++)
{
while (in[ask[i]]<in[stack[top]] || in[ask[i]]>out[stack[top]]) --top;
int minn=find(stack[top],ask[i]);
addline(stack[top],ask[i],minn);//因为一定是上连下,所以不用连双向边
stack[++top]=ask[i];
}
treedp(1);
printf("%lld\n",dp[1]);
}
int main()
{
mi[0]=1;
for (int i=1;i<sz;i++) mi[i]=mi[i-1]*2;
scanf("%d",&n);
for (int i=1;i<n;i++)
{
int x,y,w;
scanf("%d%d%d",&x,&y,&w);
addline(x,y,w);addline(y,x,w);
}
dfs(1,0);memset(point,0,sizeof(point));
scanf("%d",&m);
while (m) work(),m--;
}