某学长:这题啊,树剖啊,裸的,结果当我写了100+树剖,猛然发现,好像这个树剖除了求lca就没有任何卵用了,so,为毛不用倍增,不过还是有用,洛谷上时限卡的紧,倍增求lca根本不行,T的不要不要的,好在现在管理员把时限改过来了,可以放心使用了
思路嘛,就是二分答案,然后找出所有比二分出答案大的路径,那么这些路径一定是需要改进的对吧,也就是说需要删除的那一条边一定是这些路径的一个交集,所以差分记录每一个边用了几次,求得cnt==大于二分的边数的边,然后检验除去他以后答案是否合法,即 len>=Max-x(二分结果)就好了,复杂度Onlogn2
洛谷上的代码(为了卡常数,优化了一堆,所以看起来比较鬼畜,比如说折半查找qaq,别怕,还有正常的):
#include<cstdio>
#include<cstring>
#include<iostream>
#define ls u<<1,l,mid
#define rs u<<1|1,mid+1,r
#define maxn 300020
#define mmax(a,b) (a>b?a:b)
#define sswap(a,b) (a^=b^=a^=b)
#define adde(a,b,c) (e[tot].v=b,e[tot].next=head[a],e[tot].w=c,head[a]=tot++)
using namespace std;
int n,m,cnt,head[maxn],tot,top[maxn],son[maxn],f[maxn],size[maxn],dis[maxn*2];
int w[maxn],h[maxn],ans,Max,sum[maxn],A[maxn],B[maxn],d[maxn],C[maxn];
struct edge{int v,next,w;}e[maxn*2];
bool ok;int pos1,pos2;
void read(int& x){
char c=getchar();x=0;
for(;c>'9'||c<'0';c=getchar());
for(;c>='0'&&c<='9';c=getchar())x=x*10+c-'0';
}
void dfs1(int u,int fa){
f[u]=fa,h[u]=h[fa]+1,size[u]=1;
for(int v,i=head[u];~i;i=e[i].next){
v=e[i].v;if(v==fa)continue;
d[v]=d[u]+e[i].w;
w[v]=e[i].w;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs2(int u,int fa,int tt){
top[u]=tt;
if(son[u])dfs2(son[u],u,tt);
for(int v,i=head[u];i!=-1;i=e[i].next){
v=e[i].v;if(v==fa||v==son[u])continue;
dfs2(v,u,v);
}
}
inline int lca(int a,int b){
while(top[a]!=top[b]){
if(h[top[a]]>h[top[b]])sswap(a,b);
b=f[top[b]];
}
if(a==b)return a;
if(h[a]>h[b])sswap(a,b);
return a;
}
void dfs(int u,int fa){
for(int v,i=head[u];i!=-1;i=e[i].next ){
v=e[i].v;if(v==fa)continue;
dfs(v,u);
sum[u]+=sum[v];
}
if(sum[u]==pos1&&w[u]>=pos2)ok=true;
}
bool check(int x){
for(int i=1;i<=n;i++)sum[i]=0;
int rec=0,cnt=0,Max=0,y;
for(int i=1;i<=(1+m)>>1;i++){
if(dis[i]>x){
cnt++;
Max=mmax(Max,dis[i]-x);
sum[A[i]]++,sum[B[i]]++,sum[C[i]]-=2;
}
y=m-i+1;
if(dis[y]>x){
cnt++;
Max=mmax(Max,dis[y]-x);
sum[A[y]]++,sum[B[y]]++,sum[C[y]]-=2;
}
}
ok=false,pos1=cnt,pos2=Max;
dfs(1,1);
return ok;
}
int main(){
memset(head,-1,sizeof(head));
read(n),read(m);
for(int a,b,c,i=1;i<n;i++){
read(a),read(b),read(c);
adde(a,b,c);adde(b,a,c);
}
dfs1(1,1);dfs2(1,1,1);
int l=0,r=0;
for(int a,b,c,i=1;i<=m;i++){
read(a),read(b);
A[i]=a,B[i]=b;
c=lca(a,b);
C[i]=c;
dis[i]=d[a]+d[b]-2*d[c];
r=mmax(r,dis[i]);
}
int mid;
while(l<r){
mid=l+r>>1;
if(check(mid))r=mid;
else l=mid+1;
}
printf("%d",l);
return 0;
}
BZOJ:
#include<cstdio>
#include<cstring>
#include<iostream>
#define ls u<<1,l,mid
#define rs u<<1|1,mid+1,r
#define maxn 300020
using namespace std;
int n,m,cnt,head[maxn*2],tot,top[maxn],son[maxn],f[maxn],size[maxn],dis[maxn*2];
int w[maxn],h[maxn],ans,Max,sum[maxn],A[maxn],B[maxn],d[maxn],C[maxn];
struct edge{int v,next,w;}e[maxn*2];
void adde(int a,int b,int c){e[tot].v=b,e[tot].next=head[a],e[tot].w=c;head[a]=tot++;}
void read(int& x){
char c=getchar();x=0;
for(;c>'9'||c<'0';c=getchar());
for(;c>='0'&&c<='9';c=getchar())x=x*10+c-'0';
}
void dfs1(int u,int fa){
f[u]=fa,h[u]=h[fa]+1,size[u]=1;
for(int v,i=head[u];~i;i=e[i].next){
v=e[i].v;if(v==fa)continue;
d[v]=d[u]+e[i].w;
w[v]=e[i].w;
dfs1(v,u);
size[u]+=size[v];
if(size[v]>size[son[u]])son[u]=v;
}
}
void dfs2(int u,int fa,int tt){
top[u]=tt;
if(son[u])dfs2(son[u],u,tt);
for(int v,i=head[u];i!=-1;i=e[i].next){
v=e[i].v;if(v==fa||v==son[u])continue;
dfs2(v,u,v);
}
}
int lca(int a,int b){
while(top[a]!=top[b]){
if(h[top[a]]>h[top[b]])swap(a,b);
b=f[top[b]];
}
if(a==b)return a;
if(h[a]>h[b])swap(a,b);
return a;
}
void dfs(int u,int fa){
for(int v,i=head[u];i!=-1;i=e[i].next ){
v=e[i].v;if(v==fa)continue;
dfs(v,u);
sum[u]+=sum[v];
}
}
bool check(int x){
for(int i=1;i<=n;i++)sum[i]=0;
int rec=0,cnt=0,Max=0;
for(int i=1;i<=m;i++)
if(dis[i]>x){
cnt++;
Max=max(Max,dis[i]-x);
sum[A[i]]++,sum[B[i]]++,sum[C[i]]-=2;
}
dfs(1,1);
for(int i=1;i<=n;i++){
if(sum[i]==cnt&&w[i]>=Max)return true;
}
return false;
}
int main(){
memset(head,-1,sizeof(head));
read(n),read(m);
for(int a,b,c,i=1;i<n;i++){
read(a),read(b),read(c);
adde(a,b,c),adde(b,a,c);
}
dfs1(1,1);dfs2(1,1,1);
int l=0,r=0;
for(int a,b,c,i=1;i<=m;i++){
read(a),read(b);
A[i]=a,B[i]=b;
c=lca(a,b);
C[i]=c;
dis[i]=d[a]+d[b]-2*d[c];
r=max(r,dis[i]);
}
while(l<r){
int mid=l+r>>1;
if(check(mid))r=mid;
else l=mid+1;
}
printf("%d",l);
return 0;
}