背景:
好久之前的坑…
题目传送门:
https://www.luogu.org/problemnew/show/P2680
题意:
有一棵树,现在让你选取一条边的距离改为
0
0
0,求树上
m
m
m个点对到达的用时(走单位
1
1
1的距离用时为
1
1
1)。
思路:
谨慎阅读,非正解+卡常。
显然我们不能暴力。于是我们考虑将哪一条边的权值改为
0
0
0的问题求最小值就变成了二分最后的时间(有单调性)。
具体来说就是我们先算出每一个点对的距离,记录下来。
在二分
c
h
e
c
k
check
check的时候将距离大于
m
i
d
mid
mid的点对所经过的所有边都加上
1
1
1(在树剖的树上新开一个变量来记录),并记录下一共有多少个点对不符合要求(在下面的代码中记作
g
e
s
h
u
geshu
geshu),因为这些路径上至少有一条边的路径要改为
0
0
0。
最后再看看每一个树上的每一条边被经过的次数(只是不合法的点对贡献的)是否大于等于
g
e
s
h
u
geshu
geshu,此时这一条边的权值必须改为
0
0
0,根据所有点对的最长距离-这条边的权值(算出的就是把这条边的权值改为
0
0
0,此时的最大用时)与
m
i
d
mid
mid的大小关系,相应的改变二分的边界即可。
T i p s Tips Tips(卡常):先排序预处理的每一个点对的距离能节省一部分时间(从大到小降序),且有数据范围 t i ≤ 1000 t_i≤1000 ti≤1000可知二分的边界为 m a x ( 0 , 每 一 个 点 对 的 最 大 距 离 − 1000 ) ∼ 每 一 个 点 对 的 最 大 距 离 max(0,每一个点对的最大距离-1000) \sim每一个点对的最大距离 max(0,每一个点对的最大距离−1000)∼每一个点对的最大距离(相当于减去了某一条边的长度)。
时间复杂度:
Θ
(
n
l
o
g
3
n
+
大
常
数
)
\Theta(nlog^3n+大常数)
Θ(nlog3n+大常数)
代码:
#include<cstdio>
#include<cstring>
#include<algorithm>
#define R register
#define I inline
using namespace std;
int n,m,q,u=0,len=0,ANS;
struct node1{int x,y,z,next;} a[600010];
struct node2{int x,y,z;} b[600010];
struct node3{int l,r,lc,rc,n,d[2],lazy[2];} tr[600010];
struct node4{int x,y,ans;} c[600010];
int last[600010],p[600010],tot[600010],son[600010],fa[600010],dep[600010],ys[600010],top[600010];
I void ins(int x,int y,int z)
{
a[++len]=(node1){x,y,z,last[x]};last[x]=len;
}
I char cn()
{
static char buf[1000010],*p1=buf,*p2=buf;
return p1==p2&&(p2=(p1=buf)+fread(buf,1,1000000,stdin),p1==p2)?EOF:*p1++;
}
I void read(int &x)
{
x=0;int f1=1;char ch=cn();
while(ch<'0'||ch>'9'){if(ch=='-')f1=-1;ch=cn();}
while(ch>='0'&&ch<='9')x=x*10+(ch-'0'),ch=cn();
x*=f1;
}
void build(int l,int r)
{
int now=++len;
tr[now]=(node3){l,r,-1,-1,r-l+1,0,0,0,0};
if(l<r)
{
int mid=(l+r)>>1;
tr[now].lc=len+1; build(l,mid);
tr[now].rc=len+1; build(mid+1,r);
}
}
I void update(int id,int now)
{
if(tr[now].lazy[id])
{
int lc=tr[now].lc,rc=tr[now].rc;
if(lc!=-1) tr[lc].d[id]+=tr[lc].n*tr[now].lazy[id],tr[lc].lazy[id]+=tr[now].lazy[id];
if(rc!=-1) tr[rc].d[id]+=tr[rc].n*tr[now].lazy[id],tr[rc].lazy[id]+=tr[now].lazy[id];
tr[now].lazy[id]=0;
}
}
void change(int id,int now,int l,int r,int k)
{
update(id,now);
if(tr[now].l==l&&tr[now].r==r)
{
tr[now].d[id]+=tr[now].n*k;
tr[now].lazy[id]+=k;
return;
}
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)>>1;
if(mid+1<=l) change(id,rc,l,r,k);
else if(r<=mid) change(id,lc,l,r,k);
else change(id,lc,l,mid,k),change(id,rc,mid+1,r,k);
tr[now].d[id]=tr[lc].d[id]+tr[rc].d[id];
}
int findsum(int id,int now,int l,int r)
{
update(id,now);
if(tr[now].l==l&&tr[now].r==r) return tr[now].d[id];
int lc=tr[now].lc,rc=tr[now].rc,mid=(tr[now].l+tr[now].r)>>1;
if(mid+1<=l) return findsum(id,rc,l,r);
else if(r<=mid) return findsum(id,lc,l,r);
else return findsum(id,lc,l,mid)+findsum(id,rc,mid+1,r);
}
void dfs1(int x)
{
tot[x]=1;
son[x]=0;
for(R int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=fa[x])
{
fa[y]=x;
dep[y]=dep[x]+1;
dfs1(y);
if(tot[son[x]]<tot[y]) son[x]=y;
tot[x]+=tot[y];
}
}
}
void dfs2(int x,int tp)
{
ys[x]=++u;
top[x]=tp;
if(son[x]) dfs2(son[x],tp);
for(R int i=last[x];i;i=a[i].next)
{
int y=a[i].y;
if(y!=son[x]&&y!=fa[x]) dfs2(y,y);
}
}
I void add(int x,int y)
{
int tx=top[x],ty=top[y];
while(tx!=ty)
{
if(dep[tx]>dep[ty]) swap(x,y),swap(tx,ty);
change(1,1,ys[ty],ys[y],1);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
if(x!=y) change(1,1,ys[x]+1,ys[y],1);
}
I int solve(int id,int x,int y)
{
if(x==y) return 0;
int tx=top[x],ty=top[y],ans=0;
while(tx!=ty)
{
if(dep[tx]>dep[ty]) swap(x,y),swap(tx,ty);
ans+=findsum(id,1,ys[ty],ys[y]);
y=fa[ty];ty=top[y];
}
if(dep[x]>dep[y]) swap(x,y);
return ans+findsum(id,1,ys[x],ys[y])-findsum(id,1,ys[x],ys[x]);
}
I bool check(int x)
{
int geshu=0;
for(R int i=1;i<=(n<<1);i++)
tr[i].d[1]=0,tr[i].lazy[1]=0;
for(R int i=1;i<=m;i++)
if(c[i].ans>x) add(c[i].x,c[i].y),geshu++; else break;
for(R int i=1;i<n;i++)
if(findsum(1,1,max(ys[b[i].x],ys[b[i].y]),max(ys[b[i].x],ys[b[i].y]))>=geshu) return c[1].ans-b[i].z<=x;
return false;
}
bool cmp2(node2 x,node2 y)
{
return x.z>y.z;
}
bool cmp4(node4 x,node4 y)
{
return x.ans>y.ans;
}
int main()
{
int l=0,r=0,mid;
read(n),read(m);
for(R int i=1;i<n;i++)
read(b[i].x),read(b[i].y),read(b[i].z);
sort(b+1,b+n-1+1,cmp2);
for(R int i=1;i<n;i++)
ins(b[i].x,b[i].y,b[i].z),ins(b[i].y,b[i].x,b[i].z);
dep[1]=1;
fa[1]=0;
dfs1(1);
dfs2(1,1);
len=0;
build(1,n);
for(R int i=1;i<n;i++)
change(0,1,max(ys[b[i].x],ys[b[i].y]),max(ys[b[i].x],ys[b[i].y]),b[i].z);
for(R int i=1;i<=m;i++)
{
read(c[i].x),read(c[i].y);
c[i].ans=solve(0,c[i].x,c[i].y);
r=max(r,c[i].ans);
}
sort(c+1,c+m+1,cmp4);
l=max(0,r-1000);
while(l<=r)
{
mid=(l+r)>>1;
if(check(mid)) ANS=mid,r=mid-1; else l=mid+1;
}
printf("%d",ANS);
}