题意:在一棵树上有多个块,每一块包含树上的多个点,查询两个块间的最短距离(距离定义为在两个块中各取一个点,两点在树上的距离)。
想法:
实测每个块包含的点并不多,因此可以两重循环遍历所有可能的点对取最短距离。因此问题转化为求树上两点间的距离。
nlogn预处理每一个点i到其2^j级祖先的距离,之后即可以logn时间处理每次查询。
#include<cstdio>
#include<cstring>
#include<string>
#include<cmath>
#include<iostream>
#include<algorithm>
#include<queue>
#include<stack>
#include<set>
#include<map>
#include<deque>
#include<vector>
#include<functional>
using namespace std;
#define LL long long
#define mm(a,b) memset(a,b,sizeof(a))
const double eps=1.0e-6;
const double PI=acos(-1.0);
template<typename T>T gcd(T a,T b){return b==0?a:gcd(b,a%b);}
template<typename T>T lcm(T a,T b){return a/gcd(a,b)*b;}
template<typename T>T _abs(T a){return a>0?a:-a;}
typedef pair<int,int> P;
const int maxn=100010;
const int inf=1<<30;
struct edge
{
int to,next,w;
}e[maxn*2];
int cnt,head[maxn];
int dep[maxn];//深度;
int fa[maxn],w[maxn];//父亲;到父亲的距离
vector<int> v[maxn];
int sz[maxn];
void add(int x,int y,int z)
{
e[cnt].to=y;
e[cnt].w=z;
e[cnt].next=head[x];
head[x]=cnt++;
}
void dfs(int rt,int d,int f)
{
dep[rt]=d;
fa[rt]=f;
for(int k=head[rt];~k;k=e[k].next)
if(e[k].to!=f)
{
w[e[k].to]=e[k].w;
dfs(e[k].to,d+1,rt);
}
}
//模板
int anc[maxn][20],cost[maxn][20];
void preprocess(int n)
{
for(int i=1;i<=n;i++)
{
anc[i][0]=fa[i];
cost[i][0]=w[i];
for(int j=1;(1<<j)<n;j++)
anc[i][j]=-1;
}
for(int j=1;(1<<j)<n;j++)
for(int i=1;i<=n;i++)
if(anc[i][j-1]!=-1)
{
int x=anc[i][j-1];
anc[i][j]=anc[x][j-1];
cost[i][j]=cost[i][j-1]+cost[x][j-1];
}
}
int query(int p,int q)
{
int log;
if(dep[p]<dep[q])
swap(p,q);
for(log=1;(1<<log)<=dep[p];log++);
log--;
int ans=0;
for(int i=log;i>=0;i--)
if(dep[p]-(1<<i)>=dep[q])
{
ans+=cost[p][i];
p=anc[p][i];
}
if(p==q)
return ans;
for(int i=log;i>=0;i--)
if(anc[p][i]!=-1&&anc[p][i]!=anc[q][i])
{
ans+=cost[p][i];
p=anc[p][i];
ans+=cost[q][i];
q=anc[q][i];
}
ans+=w[p]+w[q];
return ans;
}
int main()
{
int n,m,t,x,y,z;
scanf("%d",&t);
while(t--)
{
scanf("%d%d",&n,&m);
cnt=0;
mm(head,-1);
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
dfs(1,0,-1);
for(int i=1;i<=m;i++)
{
scanf("%d",&x);
v[i].clear();
for(int j=1;j<=x;j++)
{
scanf("%d",&y);
v[i].push_back(y);
}
sort(v[i].begin(),v[i].end());
sz[i]=unique(v[i].begin(),v[i].end())-v[i].begin();//去重
}
preprocess(n);
scanf("%d",&x);
while(x--)
{
scanf("%d%d",&y,&z);
int ans=inf;
for(int i=0;i<sz[y];i++)
for(int j=0;j<sz[z];j++)
ans=min(ans,query(v[y][i],v[z][j]));
printf("%d\n",ans);
}
}
return 0;
}