模板题,利用倍增维护最小生成树上的最大和次大边,然后枚举不在最小生成树中的边,求解
#include <bits/stdc++.h>
#define inf 0x7fffffff
#define ll long long
#define int long long
//#define double long double
//#define double long long
#define re register int
//#define void inline void
#define eps 1e-8
//#define mod 1e9+7
#define ls(p) p<<1
#define rs(p) p<<1|1
#define pi acos(-1.0)
#define pb push_back
#define mk make_pair
#define P pair < int , int >
using namespace std;
const int mod= 998244353;
//const int inf=1e18;
const int M=1e8;
const int N=4e5+5;//??????.???? 4e8
struct node
{
int ver,edge,next;
}e[N];
struct nocr
{
int x,y,z;
}tr[N];
int tot,head[N],v[N];
int dep[N],fa[N],f[N][64],g[N][64],h[N][64];//g最大,h次大
int n,m,t,sum;
int d[N];
bool cmp(nocr i,nocr j)
{
return i.z<j.z;
}
int get(int x)
{
if(fa[x]==x) return x;
return fa[x]=get(fa[x]);
}
void add(int x,int y,int z)
{
e[++tot].ver=y;
e[tot].edge=z;
e[tot].next=head[x];
head[x]=tot;
}
void addedge(int x,int y,int z)
{
add(x,y,z);add(y,x,z);
}
void dfs(int x,int pre,int w)
{
d[x]=d[pre]+1,f[x][0]=pre,g[x][0]=w,h[x][0]=-1e18;
for(re i=1;i<=t;i++)
{
f[x][i]=f[f[x][i-1]][i-1];
g[x][i]=max(g[x][i-1],g[f[x][i-1]][i-1]);
h[x][i]=max(h[x][i-1],h[f[x][i-1]][i-1]);
if(g[x][i-1]>g[f[x][i-1]][i-1]) h[x][i]=max(h[x][i],g[f[x][i-1]][i-1]);
else if(g[x][i-1]<g[f[x][i-1]][i-1]) h[x][i]=max(h[x][i],g[x][i-1]);
// else h[x][i]=max(h[x][i-1],h[f[x][i-1]][i-1]);
}
for(re i=head[x];i;i=e[i].next)
{
int y=e[i].ver;
int z=e[i].edge;
if(y==pre) continue;
dfs(y,x,z);
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(re i=t;i>=0;i--) if(d[f[y][i]]>=d[x]) y=f[y][i];
if(x==y) return x;
for(re i=t;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
return f[x][0];
}
int get(int x,int y,int ma)
{
int ans=-1e18;
for(re i=t;i>=0;i--) if(d[f[x][i]]>=d[y])
{
if(ma!=g[x][i]) ans=max(ans,g[x][i]);
else ans=max(ans,h[x][i]);
x=f[x][i];
}
return ans;
}
void print()
{
int ans=1e18;
for(re i=1;i<=m;i++)
{
if(v[i]) continue;
int x=tr[i].x,y=tr[i].y,z=tr[i].z;
int LCA=lca(x,y);
int ux=get(x,LCA,z),uy=get(y,LCA,z);
ans=min(ans,sum-max(ux,uy)+z);
}
// for(re i=1;i<=m;i++)
// {
// if(v[i]) continue;
// int x
// }
cout<<ans<<endl;
}
void solve()
{
cin>>n>>m;
t=(int)(log(n)/log(2)+1);
for(re i=1;i<=n;i++) fa[i]=i;
for(re i=1;i<=m;i++) scanf("%lld%lld%lld",&tr[i].x,&tr[i].y,&tr[i].z);
sort(tr+1,tr+m+1,cmp);
for(re i=1;i<=m;i++)
{
int xx=get(tr[i].x);
int yy=get(tr[i].y);
if(xx==yy) continue;
addedge(tr[i].x,tr[i].y,tr[i].z);
sum+=tr[i].z;
fa[xx]=yy;
v[i]=1;
}
dfs(1,1,0);
print();
}
signed main()
{
int T=1;
// cin>>T;
for(int index=1;index<=T;index++)
{
solve();
// puts("");
}
return 0;
}
/*
*/