arc 048 d
假设要从 s s s到 t t t,中途在 i i i的位置出去走到一个特殊点,然后再回来走到t。
可以发现从i出去走到的关键点一定是距离i最近的那个关键点。
首先处理出所有点到最近的关键点的距离,记为 d i s i dis_i disi。
则从i走到最近的那个关键点再走回i需要花费 3 × d i s i 3\times dis_i 3×disi的时间。
然后再分类讨论i的位置:
- 在 s s s到 l c a ( s , t ) lca(s,t) lca(s,t)的路径上。则花费的时间为:
2 ∗ d e p t h [ s ] − 2 ∗ d e p t h [ i ] + d e p t h [ i ] − d e p t h [ l c a ] + 3 d i s [ i ] + d e p t h [ t ] − d e p t h [ l c a ] 2*depth[s]-2*depth[i]+depth[i]-depth[lca]+3dis[i]+depth[t]-depth[lca] 2∗depth[s]−2∗depth[i]+depth[i]−depth[lca]+3dis[i]+depth[t]−depth[lca]
也就是说只需要找到路径$s$到lca上最小的-depth[i]+3dis[i]。
- 在lca到t的路径上,花费的时间为:
2 d e p t h [ s ] − 2 d e p t h [ l c a ] + d e p t h [ t ] − d e p t h [ i ] + 2 d e p t h [ i ] − 2 d e p t h [ l c a ] + 3 d i s [ i ] 2depth[s]-2depth[lca]+depth[t]-depth[i]+2depth[i]-2depth[lca]+3dis[i] 2depth[s]−2depth[lca]+depth[t]−depth[i]+2depth[i]−2depth[lca]+3dis[i]
与上面类似,只需要找到在lca到t的路径上最小的depth[i]+3dis[i]。
由于没有修改,可以用倍增算最小值。
IH19980412
#include <bits/stdc++.h>
using namespace std;
typedef pair<int,int> P;
#define fi first
#define sc second
#define mp make_pair
#define pb push_back
#define mod 1000000007
typedef long long ll;
int n,m; string s;
vector<int>e[100005];
int db[100005][18];
int d[100005][18];
int ee[100005][18];
vector<int>T;
int ds[100005],D[100005];
bool used[100005];
int dfs(int v,int u){
db[v][0] = u;
for(int i=0;i<e[v].size();i++){
if(e[v][i] == u) continue;
D[e[v][i]] = D[v]+1;
dfs(e[v][i],v);
}
}
int rd(int s,int g){
int x = 10000000;
for(int i=17;i>=0;i--){
if(db[s][i] != -1 && D[g] <= D[db[s][i]]){
x = min(x,d[s][i]);
s = db[s][i];
}
}
return min(x,d[g][0]);
}
int re(int s,int g){
int x = 10000000;
for(int i=17;i>=0;i--){
if(db[s][i] != -1 && D[g] <= D[db[s][i]]){
x = min(x,ee[s][i]);
s = db[s][i];
}
}
return min(x,ee[g][0]);
}
int lca(int u,int v){
if(D[u]>D[v]) swap(u,v);
for(int i=0;i<18;i++){
if( ((D[v]-D[u])>>i)&1) {
v = db[v][i];
}
}
if(u == v) return u;
for(int i=17;i>=0;i--){
if(db[u][i] != db[v][i]){
v = db[v][i];
u = db[u][i];
}
}
return db[v][0];
}
int calc(int u,int v){
if(D[u]>D[v]) swap(u,v);int x=u,y=v;
for(int i=0;i<18;i++){
if( ((D[v]-D[u])>>i)&1) {
v = db[v][i];
}
}
int c;
if(u == v) c=u;
else{
for(int i=17;i>=0;i--){
if(db[u][i] != db[v][i]){
v = db[v][i];
u = db[u][i];
}
}
c = db[v][0];
}
return D[x]+D[y]-2*D[c];
}
int main(){
cin >> n >> m;
for(int i=1;i<n;i++){
int a,b; scanf("%d%d",&a,&b);
e[a].pb(b);
e[b].pb(a);
}
cin >> s;
for(int i=0;i<n;i++){
if(s[i] == '1') T.pb(i+1);
}
priority_queue<P,vector<P>,greater<P> >que;
fill(ds,ds+100005,10000000);
for(int i=0;i<T.size();i++){
que.push(mp(0,T[i]));
ds[T[i]] = 0;
}
while(!que.empty()){
P p = que.top(); que.pop();
if(ds[p.sc] != p.fi) continue;
for(int i=0;i<e[p.sc].size();i++){
if(ds[e[p.sc][i]] > p.fi+1){
que.push(mp(p.fi+1,e[p.sc][i]));
ds[e[p.sc][i]] = p.fi+1;
}
}
}
memset(db,-1,sizeof(db));
dfs(1,-1);
for(int i=1;i<=n;i++){
d[i][0] = -D[i]+3*ds[i];
ee[i][0] = D[i]+3*ds[i];
}
for(int j=0;j<17;j++) {for(int i=1;i<=n;i++){
if(db[i][j] == -1){
db[i][j+1] = -1;
d[i][j+1] = d[i][j];
ee[i][j+1] = ee[i][j];
}
else{
db[i][j+1] = db[db[i][j]][j];
d[i][j+1] = min(d[i][j],d[db[i][j]][j]);
ee[i][j+1] = min(ee[i][j],ee[db[i][j]][j]);
}
}}
for(int i=0;i<m;i++){
int s,g; scanf("%d%d",&s,&g);
int v = lca(s,g);
int dist = calc(s,g);
int dist2 = calc(v,g);
int L = rd(s,v);
int R = re(g,v);
L += dist*2-dist2+D[v];
R += dist*2-dist2-D[v];
printf("%d\n",min(dist*2,min(L,R)));
}
}