题目描述:
n个点的树,选出k条道路,要求k在[L,U]之间,且道路的平均权值最大。(k不定)
n<=100000, 边权<=106
题目分析:
平均值最大比较常见的思路是二分,每条边减去mid后求最大值看是否大于等于0。
用 f [ u ] [ i ] f[u][i] f[u][i]表示以 u u u为根长度为 i i i的链的最大权值(加上了从 u u u到根的权值,方便计算)
我们需要在添加儿子 v v v的时候对于 v v v的每个深度,区间查询匹配的在 [ L , U ] [L,U] [L,U]之间的 u u u的深度的最大权值,来更新答案。还需要将儿子深度对应的权值更新父亲对应的权值。
注意到复杂度只跟深度有关,我们可以长链剖分,这样父亲和长儿子在dfs序上就是连续的,我们可以用线段树来存 f f f。这样每条长链只会在顶端将整条链遍历一次,遍历的复杂度是O(n),加上线段树就是O(nlogn)。
Code:
#include<cstdio>
#include<algorithm>
#define maxn 100005
#define LL long long
using namespace std;
const double inf = 1e20;
int n,L,U,mxd[maxn],dep[maxn],son[maxn],dfn[maxn],tim;
double dis[maxn],mx[maxn<<2],ans,Mid,tmp[maxn];
int fir[maxn],nxt[maxn<<1],to[maxn<<1],w[maxn<<1],tot;
inline void line(int x,int y,int z){nxt[++tot]=fir[x],fir[x]=tot,to[tot]=y,w[tot]=z;}
void dfs1(int u,int ff){
mxd[u]=dep[u]=dep[ff]+1;
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=ff){
dfs1(v,u);mxd[u]=max(mxd[u],mxd[v]);
if(mxd[v]>mxd[son[u]]) son[u]=v;
}
}
void dfs2(int u){
dfn[u]=++tim;
if(son[u]) dfs2(son[u]);
for(int i=fir[u];i;i=nxt[i]) if(!dfn[to[i]]) dfs2(to[i]);
}
void build(int i,int l,int r){
mx[i]=-inf;
if(l==r) return;
int mid=(l+r)>>1;
build(i<<1,l,mid),build(i<<1|1,mid+1,r);
}
void insert(int i,int l,int r,int x,double d){
if(l==r) {mx[i]=max(mx[i],d);return;}
int mid=(l+r)>>1;
if(x<=mid) insert(i<<1,l,mid,x,d);
else insert(i<<1|1,mid+1,r,x,d);
mx[i]=max(mx[i<<1],mx[i<<1|1]);
}
double query(int i,int l,int r,int x,int y){
if(x<=l&&r<=y) return mx[i];
int mid=(l+r)>>1;double ret=-inf;
if(x<=mid) ret=max(ret,query(i<<1,l,mid,x,y));
if(y>mid) ret=max(ret,query(i<<1|1,mid+1,r,x,y));
return ret;
}
void solve(int u,int ff){
insert(1,1,n,dfn[u],dis[u]);
for(int i=fir[u];i;i=nxt[i]) if(to[i]==son[u]) dis[to[i]]=dis[u]+w[i]-Mid,solve(to[i],u);
for(int i=fir[u],v;i;i=nxt[i]) if((v=to[i])!=son[u]&&v!=ff){
dis[v]=dis[u]+w[i]-Mid,solve(v,u);
for(int j=0;j<=mxd[v]-dep[v];j++){
tmp[j]=query(1,1,n,dfn[v]+j,dfn[v]+j);
if(L-(j+1)<=mxd[u]-dep[u]&&U-(j+1)>=0)
ans=max(ans,query(1,1,n,dfn[u]+max(0,L-(j+1)),dfn[u]+min(mxd[u]-dep[u],U-(j+1)))+tmp[j]-2*dis[u]);
}
for(int j=0;j<=mxd[v]-dep[v];j++) insert(1,1,n,dfn[u]+j+1,tmp[j]);
}
if(L<=mxd[u]-dep[u]) ans=max(ans,query(1,1,n,dfn[u]+L,dfn[u]+min(mxd[u]-dep[u],U))-dis[u]);
}
int main()
{
scanf("%d%d%d",&n,&L,&U);
for(int i=1,x,y,z;i<n;i++) scanf("%d%d%d",&x,&y,&z),line(x,y,z),line(y,x,z);
dfs1(1,0),dfs2(1);
double l=0,r=1e6;
while(r-l>1e-4){
Mid=(l+r)/2;
build(1,1,n),ans=-inf,solve(1,0);
if(ans>=0) l=Mid;
else r=Mid;
}
printf("%.3f\n",l);
}
但是还没完。
平均值最大还有一种神乎其技的操作叫迭代,比二分快十倍。。
具体来说就是先随便设一个平均值mid,然后每条边减去这个值,再求出最大值以及其对应的长度。根据最大值和长度算出真正的平均值ans,然后令mid=ans,再重复这个过程,如果mid和ans的差值符合精度要求就可以停止了。
Code:
#include<cstdio>
#include<cctype>
#include<cmath>
#include<algorithm>
#define maxn 100005
#define LL long long
using namespace std;
char cb[1<<15],*cs,*ct;
#define getc() (cs==ct&&(ct=(cs=cb)+fread(cb,1,1<<15,stdin),cs==ct)?0:*cs++)
template<class T>inline void read(T &a){
char c;bool f=0;while(!isdigit(c=getc())) if(c=='-') f=1;
for(a=c-'0';isdigit(c=getc());a=a*10+c-'0'); if(f) a=-a;
}
const double inf = 1e20;
int n,m,L,R,mxd[maxn],dep[maxn],son[maxn],dfn[maxn],pos[maxn],tim;
double ans,mid,dis[maxn];
int fir[maxn],tot,Eson[maxn],tmp[maxn];
struct edge{int nxt,to,w;}e[maxn<<1];
inline void line(int x,int y,int z){e[++tot]=(edge){fir[x],y,z};fir[x]=tot;}
void dfs1(int u,int pre){
for(int i=fir[u],v;i;i=e[i].nxt) if((v=e[i].to)!=pre){
mxd[v]=dep[v]=dep[u]+1;
dfs1(v,u);mxd[u]=max(mxd[u],mxd[v]);
if(mxd[v]>mxd[son[u]]) son[u]=v,Eson[u]=i;
}
}
void dfs2(int u){
dfn[u]=++tim;
if(son[u]) dfs2(son[u]);
for(int i=fir[u];i;i=e[i].nxt) if(!dfn[e[i].to]) dfs2(e[i].to);
}
int id[maxn<<2];
inline void chkmax(int &i,int j){if(dis[i]<dis[j]) i=j;}
void build(int i,int l,int r){
id[i]=0;
if(l==r) {pos[l]=i;return;}
int mid=(l+r)>>1;
build(i<<1,l,mid),build(i<<1|1,mid+1,r);
}
void insert(int i,int l,int r,int x,int p){
chkmax(id[i],p);
if(l==r) return;
int mid=(l+r)>>1;
if(x<=mid) insert(i<<1,l,mid,x,p);
else insert(i<<1|1,mid+1,r,x,p);
}
int query(int i,int l,int r,int x,int y){
if(x>y) return 0;
if(x<=l&&r<=y) return id[i];
int mid=(l+r)>>1,ret=0;
if(x<=mid) chkmax(ret,query(i<<1,l,mid,x,y));
if(y>mid) chkmax(ret,query(i<<1|1,mid+1,r,x,y));
return ret;
}
inline void calc(int a,int b,int c){
if(!a||!b) return;
static int pts; pts = (dep[a]+dep[b]-2*dep[c]);
ans=max(ans,(dis[a]+dis[b]-2*dis[c]+mid*pts)/pts);
}
void solve(int u,int pre){
insert(1,1,n,dfn[u],u);
if(son[u]) dis[son[u]]=dis[u]+e[Eson[u]].w-mid, solve(son[u],u);
for(int i=fir[u],v;i;i=e[i].nxt) if((v=e[i].to)!=pre&&v!=son[u]){
dis[v]=dis[u]+e[i].w-mid, solve(v,u);
for(int j=mxd[v]-dep[u];j>=1;j--){
tmp[j]=id[pos[dfn[v]+j-1]];
calc(tmp[j],query(1,1,n,dfn[u]+max(L-j,0),dfn[u]+min(R-j,mxd[u]-dep[u])),u);
}
for(int j=mxd[v]-dep[u];j>=1;j--) insert(1,1,n,dfn[u]+j,tmp[j]);
}
calc(query(1,1,n,dfn[u]+L,dfn[u]+min(R,mxd[u]-dep[u])),u,u);
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("H.in","r",stdin);
#endif
read(n),read(L),read(R);
for(int i=1,x,y,z;i<n;i++) read(x),read(y),read(z),line(x,y,z),line(y,x,z);
dfs1(1,0),dfs2(1);
mid=500000; dis[0]=-inf;
while(1){
ans=-inf,build(1,1,n),solve(1,0);
if(fabs(mid-ans)<1e-3) break;
mid=ans;
}
printf("%.3f",ans);
}