思路:
先上一份ACode:(具体解释思路在下面)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e5+10,M=2*N;
int n,k,x,fa[N],sum;
struct E {
int u,v,w;
} e[M];
int head[N],cnt;
bool vis[N];
bool cmp(E i,E j) {
return i.w>j.w;
}
void init() {
for(int i=1; i<=n; i++) fa[i]=i;
}
int find(int x) {
return fa[x]==x?x:fa[x]=find(fa[x]);
}
void fun() {
for(int i=1; i<=n-1; i++) {
int f1=find(e[i].u);
int f2=find(e[i].v);
if(!(vis[f1]&&vis[f2])&&f1!=f2) {
//然后合并的时候必须要两个集合不是都有敌人的城市
fa[f2]=f1;
vis[f1]=(vis[f1]||vis[f2]);//如果一个正常节点连接上了敌人节点
//那么这个正常节点也变为敌人节点,因为其他敌人点可以通过该正常点
//与其他敌人点相连
//如果连接到了被占领了的,另外一个点要被"假占领"(因为合并了,所只要改父亲占领状态就好了)
//不然的话要是另外一个占领的点连接到了这个点那么就有两个被真占领点联通了
sum-=e[i].w;
}
}
}
void solve() {
cin>>n>>k;
init();//初始化父节点
for(int i=1; i<=k; i++) { //标记被占领的城市
cin>>x;
vis[x]=true;
}
for(int i=1; i<=n-1; i++) { //建图
cin>>e[i].u>>e[i].v>>e[i].w;
sum+=e[i].w;
}
sort(e+1,e+n,cmp);//先令所有道路摧毁,
//再按大到小修建两个不都是敌人的节点
fun();
cout<<sum<<"\n";
}
signed main() {
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int tt=1;
//cin>>tt;
while(tt--) {
solve();
}
return 0;
}
ACcode:
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N=1e5+10,M=2*N;
int n,k,x,fa[N],sum;
struct E {
int u,v,w;
} e[M];
int head[N],cnt;
bool vis[N];
bool cmp(E i,E j) {
return i.w>j.w;
}
void init() {
for(int i=1; i<=n; i++) fa[i]=i;
}
int find(int x) {
return fa[x]==x?x:fa[x]=find(fa[x]);
}
void fun() {
for(int i=1; i<=n-1; i++) {
int f1=find(e[i].u);
int f2=find(e[i].v);
if(!(vis[f1]&&vis[f2])&&f1!=f2) {
//然后合并的时候必须要两个集合不是都有敌人的城市
fa[f2]=f1;
vis[f1]=(vis[f1]||vis[f2]);//如果一个正常节点连接上了敌人节点
//那么这个正常节点也变为敌人节点,因为其他敌人点可以通过该正常点
//与其他敌人点相连
//如果连接到了被占领了的,另外一个点要被"假占领"(因为合并了,所只要改父亲占领状态就好了)
//不然的话要是另外一个占领的点连接到了这个点那么就有两个被真占领点联通了
sum-=e[i].w;
}
}
}
void solve() {
cin>>n>>k;
init();//初始化父节点
for(int i=1; i<=k; i++) { //标记被占领的城市
cin>>x;
vis[x]=true;
}
for(int i=1; i<=n-1; i++) { //建图
cin>>e[i].u>>e[i].v>>e[i].w;
sum+=e[i].w;
}
sort(e+1,e+n,cmp);//先令所有道路摧毁,
//再按大到小修建两个不都是敌人的节点
fun();
cout<<sum<<"\n";
}
signed main() {
ios::sync_with_stdio(false),cin.tie(0),cout.tie(0);
int tt=1;
//cin>>tt;
while(tt--) {
solve();
}
return 0;
}
over