每日一题 P2680 [NOIP2015 提高组] 运输计划 LCA 树上差分 二分答案
紫题不愧是紫题,有一定难度。想到了二分答案但是没想清楚check函数怎么写。学到一点:用树上差分可以O(N)找出树上一些链的公共边,这样就比较好check了。代码写的比较垃圾,常数巨大。
#include <bits/stdc++.h>
#define MAXN 600005
using namespace std;
struct EDGE
{
int to,next,w;
} edge[MAXN];
int head[MAXN],ptr;
void add_edge(int u,int v,int w)
{
edge[++ptr].to=v;
edge[ptr].next=head[u];
edge[ptr].w=w;
head[u]=ptr;
}
void add(int u,int v,int w)
{
add_edge(u,v,w);
add_edge(v,u,w);
}
int dep[MAXN],f[MAXN][21],dlen[MAXN],cnt[MAXN];
int n,m;
void dfs(int now,int fa)
{
for(int p=head[now]; p; p=edge[p].next)
{
int to=edge[p].to;
if(to==fa) continue;
dep[to]=dep[now]+1;//计算深度
dlen[to]=dlen[now]+edge[p].w;
f[to][0]=now;
dfs(to,now);
}
}
void dp()
{
//预处理 f[u][i] u结点往上走2^i次
for(int i=1; (1<<i)<=n; i++)
for(int u=1; u<=n; u++)
f[u][i]=f[f[u][i-1]][i-1];//u结点往上走2^i次等于网上走两个2^(i-1)
}
int lca(int x,int y)
{
int p,t;
if(dep[x]<dep[y])
swap(x,y);//保证x比y大
for(p=0; (1<<p)<=dep[x]; p++);//算出dp最多往上走2的多少次;
for(t=--p; t>=0; t--)//从x走到y同一层
if(dep[x]-(1<<t)>=dep[y])//不超过的话就往上走
x=f[x][t];
if(x==y) return x;
for(t=p; t>=0; t--)//x和y一起走 不相等就往上走
if(f[x][t]!=f[y][t])
x=f[x][t],y=f[y][t];
return f[x][0];
}
int getlen(int x,int y,int lca)
{
return dlen[x]+dlen[y]-2*dlen[lca];
}
void addl(int x,int y,int lca)
{
cnt[x]++,cnt[y]++,cnt[lca]-=2;
}
struct NODE
{
int x,y,lca,len;
bool operator < (const NODE &t)const
{
return len<t.len;
}
};
vector<NODE> v;
int tot,maxn;
void dfs2(int now,int fa,int fe)
{
for(int p=head[now];p;p=edge[p].next)
{
int to=edge[p].to;
if(to==fa)continue;
dfs2(to,now,p);
}
cnt[fa]+=cnt[now];
if(cnt[now]==tot)maxn=max(maxn,edge[fe].w);
}
bool check(int t)
{
maxn=0;
memset(cnt,0,sizeof(int)*(n+1));
if(t>=v[v.size()-1].len) return 1;
int pos=upper_bound(v.begin(),v.end(),NODE{0,0,0,t})-v.begin();
tot=v.size()-pos;
for(int i=pos;i<v.size();i++)
addl(v[i].x,v[i].y,v[i].lca);
dfs2(1,0,0);
return v[v.size()-1].len-maxn<=t;
}
signed main()
{
cin>>n>>m;
for(int i=1; i<n; i++)
{
int a,b,w;
cin>>a>>b>>w;
add(a,b,w);
}
dfs(1,0);dp();
for(int i=1;i<=m;i++)
{
int a,b;cin>>a>>b;
int tmp=lca(a,b);
v.push_back(NODE{a,b,tmp,getlen(a,b,tmp)});
}
sort(v.begin(),v.end());
int L=0,R=300000005;
while(L<R)
{
int mid=(L+R)>>1;
if(check(mid))R=mid;
else L=mid+1;
}
cout<<L;
}