题目大意:给出一棵有根树,n组询问,每一组询问给出树上的一些关键点,问割掉一些边使得根与这些点不联通的最小花费是多少。总询问的点不超过O(n)。
思路:基础思路是每一次询问做一次O(n)的DP,这本来已经够快了,但是有很多询问,这样做就n^2了。注意到所有询问的点加起来不超过O(n),也就是说每次询问的点可能很少。那么我们为何要将所有点扫一次?只需要将询问的点重新建树,然后跑树形DP,这样DP的总时间就是O(n)了。当然瓶颈在求两点之间的最短边上,O(nlogn)的倍增。
具体做法是维护一个单调栈,所有时刻这个栈中的所有点是从根开始的深度递增的一条链。把所有点按照DFS序排序,依次加入栈中,同时维护这个栈,使它是一条链。假如新加进来的点与栈顶的LCA高于栈顶,那么就说明新加进来的点不能继续与栈顶形成链了。就将栈顶和次栈顶连边,然后弹出栈顶。还有一些小细节什么的。。
CODE:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define MAX 510010
#define INF 0x3f3f3f3f
using namespace std;
struct Complex{
int x,pos;
Complex(int _,int __):x(_),pos(__) {}
Complex() {}
bool operator <(const Complex &a)const {
return pos < a.pos;
}
}src[MAX];
int points,asks;
int head[MAX],total;
int next[MAX],aim[MAX],length[MAX];
int pos[MAX],cnt;
inline void Add(int x,int y,int len)
{
next[++total] = head[x];
aim[total] = y;
length[total] = len;
head[x] = total;
}
int father[MAX][20],_min[MAX][20];
int deep[MAX];
void DFS(int x,int last)
{
deep[x] = deep[last] + 1;
pos[x] = ++cnt;
for(int i = head[x]; i; i = next[i]) {
if(aim[i] == last) continue;
father[aim[i]][0] = x;
_min[aim[i]][0] = length[i];
DFS(aim[i],x);
}
}
void MakeTable()
{
for(int j = 1; j < 19; ++j)
for(int i = 1; i <= points; ++i) {
father[i][j] = father[father[i][j - 1]][j - 1];
_min[i][j] = min(_min[i][j - 1],_min[father[i][j - 1]][j - 1]);
}
}
inline int GetLCA(int x,int y)
{
if(deep[x] < deep[y]) swap(x,y);
for(int i = 19; ~i; --i)
if(deep[father[x][i]] >= deep[y])
x = father[x][i];
if(x == y) return x;
for(int i = 19; ~i; --i)
if(father[x][i] != father[y][i])
x = father[x][i],y = father[y][i];
return father[x][0];
}
inline int GetMin(int x,int y)
{
if(deep[x] < deep[y]) swap(x,y);
int re = INF;
for(int i = 19; ~i; --i)
if(deep[father[x][i]] >= deep[y]) {
re = min(re,_min[x][i]);
x = father[x][i];
}
for(int i = 19; ~i; --i)
if(father[x][i] != father[y][i]) {
re = min(re,_min[x][i]);
re = min(re,_min[y][i]);
x = father[x][i];
y = father[y][i];
}
if(x != y) re = min(re,min(_min[x][0],_min[y][0]));
return re;
}
struct Graph{
int head[MAX],v[MAX],T,total;
int next[MAX],aim[MAX];
int super[MAX];
long long f[MAX];
void Add(int x,int y) {
//cout << x << ' ' << y << endl;
if(v[x] != T) {
v[x] = T;
head[x] = 0;
}
next[++total] = head[x];
aim[total] = y;
head[x] = total;
}
void Set(int x) {
super[x] = T;
}
void TreeDP(int x) {
f[x] = 0;
if(v[x] != T) {
v[x] = T;
head[x] = 0;
}
for(int i = head[x]; i; i = next[i]) {
TreeDP(aim[i]);
f[x] += min(super[aim[i]] == T ? INF:f[aim[i]],(long long)GetMin(x,aim[i]));
}
}
}graph;
int main()
{
cin >> points;
for(int x,y,z,i = 1; i < points; ++i) {
scanf("%d%d%d",&x,&y,&z);
Add(x,y,z),Add(y,x,z);
}
DFS(1,0);
MakeTable();
cin >> asks;
for(int cnt,i = 1; i <= asks; ++i) {
scanf("%d",&cnt);
for(int j = 1; j <= cnt; ++j)
scanf("%d",&src[j].x),src[j].pos = pos[src[j].x];
sort(src + 1,src + cnt + 1);
++graph.T;
graph.total = 0;
static int stack[MAX];
int top = 0;
stack[++top] = 1;
for(int j = 1; j <= cnt; ++j) {
int lca = GetLCA(stack[top],src[j].x);
while(deep[lca] < deep[stack[top]]) {
if(deep[stack[top - 1]] <= deep[lca]) {
int away = stack[top--];
if(stack[top] != lca)
stack[++top] = lca;
graph.Add(lca,away);
break;
}
graph.Add(stack[top - 1],stack[top]),--top;
}
if(stack[top] != src[j].x)
stack[++top] = src[j].x;
graph.Set(src[j].x);
}
while(top)
graph.Add(stack[top - 1],stack[top]),--top;
graph.TreeDP(1);
printf("%lld\n",graph.f[1]);
}
return 0;
}