题意
一棵有正边权的树上, m m m 次询问从 x x x 号点走到节点编号在 l l l 和 r r r 之间的节点的最小距离。 n , m ≤ 1 0 5 n,m\leq 10^5 n,m≤105。时限 2s。
题解
(似乎有很多写法)
先把所有询问离线下来,并把它拆成 O ( log ) O(\log) O(log) 段 ( x , l ′ , r ′ ) (x,l',r') (x,l′,r′) 挂在线段树的各个节点。
遍历线段树的每个节点,把 这个节点对应编号的点 和 挂在这个节点上的询问的 x x x 放在虚树里,树形 DP 一下虚树中每个点到 这个线段树节点对应编号的点 的最小距离,然后用这些距离更新每个询问的答案。
(DP 本来不用 DFS 的 QAQ,因为 DFS 序都求好了)
实现得不够好的话时间复杂度为 O ( ( n + m ) log 2 n ) O((n+m)\log^2n) O((n+m)log2n)。假如实现了 O ( 1 ) O(1) O(1) 查询 LCA,并且构建虚树前把询问按 x x x 的 DFS 序排序,构建虚树时把线段树上左儿子的对应编号的点、右儿子的对应编号的点、这个挂在这个节点上的询问的 x x x 三个序列归并起来,可以达到 O ( ( n + m ) log n ) O((n+m)\log n) O((n+m)logn)。
代码( O ( ( n + m ) log 2 n ) O((n+m)\log^2n) O((n+m)log2n)):
#include<bits/stdc++.h>
using namespace std;
int getint(){
int ans=0;
char c=getchar();
while(c<'0'||c>'9')c=getchar();
while(c>='0'&&c<='9'){
ans=ans*10+c-'0';
c=getchar();
}
return ans;
}
const int N=2e5+10,L=20;
struct bian{
int l,e,n;
};
bian b[N];
int s[N],tot=0;
int n;
void add(int x,int y,int z){
tot++;
b[tot].l=z;
b[tot].e=y;
b[tot].n=s[x];
s[x]=tot;
}
int dep[N],valdep[N],dfn[N<<2],dfnn=0;
int d[N],e[N];
void ss(int x,int f){
dfn[++dfnn]=x;
e[x]=d[x]=dfnn;
for(int i=s[x];i;i=b[i].n){
if(b[i].e==f)continue;
dep[b[i].e]=dep[x]+1;
valdep[b[i].e]=valdep[x]+b[i].l;
ss(b[i].e,x);
dfn[++dfnn]=x;
e[x]=dfnn;
}
}
int st[L][N<<2],l2[N<<2];
int Min(int x,int y){
return dep[x]<dep[y]?x:y;
}
void init_st(){
l2[0]=-1;for(int i=1;i<=dfnn;i++)st[0][i]=dfn[i],l2[i]=l2[i>>1]+1;
for(int i=1;i<L;i++){
for(int j=1;j<=dfnn-(1<<i-1);j++){
st[i][j]=Min(st[i-1][j],st[i-1][j+(1<<i-1)]);
}
}
}
int get_lca(int x,int y){
x=d[x];
y=e[y]+1;
int t=l2[y-x];
return Min(st[t][x],st[t][y-(1<<t)]);
}
int quel[N],quer[N],quex[N],ans[N];
vector<int>a[N<<2];
void add_query(int l,int r,int val,int x,int nl,int nr){
if(nr<l||nl>r)return;
if(l<=nl&&nr<=r){
a[x].push_back(val);
return;
}
int mid=nl+nr>>1;
add_query(l,r,val,x<<1,nl,mid);
add_query(l,r,val,x<<1|1,mid+1,nr);
}
bian vb[N];
int vs[N],vtot=0;
void vadd(int x,int y,int z){
vtot++;
vb[vtot].l=z;
vb[vtot].e=y;
vb[vtot].n=vs[x];
vs[x]=vtot;
}
int nodes[N],knodes[N],cnt=0;
bool cmp(int x,int y){
return d[x]<d[y];
}
int fch[N],ffa[N];
void vdp1(int x,int f,int l,int r){
if(l<=x&&x<=r)fch[x]=0;
for(int i=vs[x];i;i=vb[i].n){
if(vb[i].e==f)continue;
vdp1(vb[i].e,x,l,r);
fch[x]=min(fch[vb[i].e]+vb[i].l,fch[x]);
}
//cerr<<"fch "<<x<<" "<<fch[x]<<" ("<<l<<" "<<r<<")"<<endl;
}
void vdp2(int x,int f,int l,int r){
ffa[x]=min(ffa[x],fch[x]);
for(int i=vs[x];i;i=vb[i].n){
if(vb[i].e==f)continue;
ffa[vb[i].e]=min(ffa[vb[i].e],ffa[x]+vb[i].l);
vdp2(vb[i].e,x,l,r);
}
//cerr<<"ffa "<<x<<" "<<ffa[x]<<" ("<<fch[x]<<")"<<endl;
}
void vclear(int x,int f){
fch[x]=ffa[x]=0x3f3f3f3f;
for(int i=vs[x];i;i=vb[i].n){
if(vb[i].e==f)continue;
vclear(vb[i].e,x);
}
}
void vtree(int x,int l,int r){
//for(int i=0;i<a[x].size();i++)cerr<<" "<<quex[a[x][i]];cerr<<endl;
cnt=0;
for(int i=l;i<=r;i++)nodes[cnt++]=i;
for(int i=0;i<a[x].size();i++)nodes[cnt++]=quex[a[x][i]];
sort(nodes,nodes+cnt,cmp);
for(int i=0;i<cnt;i++)knodes[i]=nodes[i];
for(int i=0;i<cnt-1;i++)knodes[i+cnt]=get_lca(nodes[i],nodes[i+1]);
sort(knodes,knodes+cnt+cnt-1,cmp);
int len=unique(knodes,knodes+cnt+cnt-1)-knodes;
stack<int>sta;
for(int i=0;i<len;i++){
while(sta.size()&&d[knodes[i]]>e[sta.top()])sta.pop();
if(sta.size()){
int w=abs(valdep[knodes[i]]-valdep[sta.top()]);
vadd(knodes[i],sta.top(),w),
vadd(sta.top(),knodes[i],w);
//cerr<<"> "<<knodes[i]<<" "<<sta.top()<<" ["<<w<<"]"<<endl;
}
sta.push(knodes[i]);
}
//cerr<<"vdp1:"<<endl;
vdp1(knodes[0],0,l,r);
//cerr<<"vdp2:"<<endl;
vdp2(knodes[0],0,l,r);
for(int i=0;i<a[x].size();i++){
ans[a[x][i]]=min(ans[a[x][i]],ffa[quex[a[x][i]]]);
//cerr<<"ans "<<a[x][i]<<" < "<<ffa[quex[a[x][i]]]
// <<" ["<<l<<" "<<r<<"]"<<endl;
}
vclear(knodes[0],0);
for(int i=0;i<len;i++)vs[knodes[i]]=0,knodes[i]=0;
for(int i=0;i<cnt;i++)nodes[i]=0;
vtot=len=cnt=0;
}
void solve(int x,int l,int r){
if(l<r){
int mid=l+r>>1;
solve(x<<1,l,mid);
solve(x<<1|1,mid+1,r);
}
//cerr<<"solve "<<x<<" "<<l<<" "<<r<<endl;
vtree(x,l,r);
}
int main(){
n=getint();
for(int i=1;i<n;i++){
int x=getint(),y=getint(),z=getint();
add(x,y,z);
add(y,x,z);
}
ss(1,0);
init_st();
int m=getint();
for(int i=1;i<=m;i++){
quel[i]=getint();
quer[i]=getint();
quex[i]=getint();
ans[i]=0x3f3f3f3f;
add_query(quel[i],quer[i],i,1,1,n);
}
memset(fch,0x3f,sizeof(fch));
memset(ffa,0x3f,sizeof(ffa));
solve(1,1,n);
for(int i=1;i<=m;i++)printf("%d\n",ans[i]);
return 0;
}