分析
题目给我们的是一颗树,如果不考虑改造,那么答案就为树上链的最大值,用倍增来维护。对于每一个询问链,必须要改造最大值才能对答案有贡献,其贡献为
min(L,链上最大值−次大值)
所以我们用倍增维护最大值与次大值,把贡献标记在最大值所在的边上(若最大值有多个,则对答案无影响),最后找出贡献最大的边即可。
代码
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
int n,q,l,x,y,z,tot,last[100005],nex[200005],to[200005],w[200005],dep[100005];
int fa[100005][20],f[100005][20],g[100005][20],p[100005][20];
long long s[100005],ans,ll;
void read(int &n)
{
int x=0,w=1;
char ch=getchar();
while (ch!='-' && (ch<'0' || ch>'9')) ch=getchar();
if (ch=='-') w=-1,ch=getchar();
while (ch<='9' && ch>='0') x=(x<<1)+(x<<3)+ch-48,ch=getchar();
n=w*x;
}
void add(int x,int y,int z)
{
to[++tot]=y;
w[tot]=z;
nex[tot]=last[x];
last[x]=tot;
}
void dg(int x,int y)
{
fa[x][0]=y;
dep[x]=dep[y]+1;
for (int i=last[x];i;i=nex[i])
if (to[i]!=y)
{
f[to[i]][0]=w[i];
p[to[i]][0]=(i+1)/2;
dg(to[i],x);
}
}
void lca(int x,int y)
{
if (dep[x]<dep[y]) swap(x,y);
int k=0,h=-2147483647,u=-2147483647,v;
for (int m=dep[x]-dep[y];m;m>>=1)
{
if (m&1)
{
if (f[x][k]>h)
{
u=max(u,h),u=max(u,g[x][k]);
h=f[x][k];
v=p[x][k];
}
else u=max(u,f[x][k]);
x=fa[x][k];
}
k++;
}
k=0;
while (x!=y)
if (!k || fa[x][k]!=fa[y][k])
{
if (f[x][k]>h)
{
u=max(u,h),u=max(u,g[x][k]);
h=f[x][k];
v=p[x][k];
}
else u=max(u,f[x][k]);
if (f[y][k]>h)
{
u=max(u,h),u=max(u,g[y][k]);
h=f[y][k];
v=p[y][k];
}
else u=max(u,f[y][k]);
x=fa[x][k];
y=fa[y][k];
k++;
}
else k--;
if (u<=h-l) s[v]+=l; else s[v]+=h-u;
ans+=h;
}
int max(int a,int b)
{
return a>b?a:b;
}
int main()
{
freopen("carry.in","r",stdin);
freopen("carry.out","w",stdout);
read(n),read(q),read(l);
for (int i=1;i<n;i++)
{
read(x),read(y),read(z);
add(x,y,z),add(y,x,z);
}
memset(f,128,sizeof(f));
memset(g,128,sizeof(g));
dg(1,0);
for (int j=1;j<=17;j++)
for (int i=1;i<=n;i++)
{
fa[i][j]=fa[fa[i][j-1]][j-1];
f[i][j]=max(f[i][j-1],f[fa[i][j-1]][j-1]);
g[i][j]=max(g[i][j-1],g[fa[i][j-1]][j-1]);
if (f[i][j]==f[i][j-1])
g[i][j]=max(g[i][j],f[fa[i][j-1]][j-1]),p[i][j]=p[i][j-1];
else g[i][j]=max(g[i][j],f[i][j-1]),p[i][j]=p[fa[i][j-1]][j-1];
}
for (int i=1;i<=q;i++)
{
read(x),read(y);
if (x!=y) lca(x,y);
}
int hu;
for (int i=1;i<n;i++) if (ll<s[i]) ll=s[i],hu=i;
if (ans-ll<ans) ans=ans-ll;
printf("%lld",ans);
return 0;
}