考场上自己想出来的做法（按点权和做长链剖分，再贪心选取权值和最大的 k 条链）。
Code:
#include<cstdio>
#include<algorithm>
#include<queue>
#include<vector>
#include<string>
using namespace std;
// Redirects stdin to "<a>.in" and stdout to "<a>.out".
// The results of freopen are not checked.
void setIO(string a){
    const string inPath = a + ".in";
    const string outPath = a + ".out";
    freopen(inPath.c_str(), "r", stdin);
    freopen(outPath.c_str(), "w", stdout);
}
// Closes the standard input and output streams (the counterpart of setIO).
// fclose flushes any pending output on stdout before closing it.
void shutIO(){
    FILE *const streams[] = { stdin, stdout };
    for (int i = 0; i < 2; ++i) fclose(streams[i]);
}
typedef long long ll;     // 64-bit type used for chain-weight sums
const int maxn = 200009;  // per-node array capacity (edge arrays use maxn<<1)
struct Node{
ll dist;
int u;
Node(ll dist=0,int u=0):dist(dist),u(u){}
bool operator<(Node a) const{
return a.dist>dist;
}
};
// Candidate chains, largest dist on top.
priority_queue<Node> Q;
// Union-find parent array and per-set "already taken" flag.
int p[maxn];
int tag[maxn];
// Input sizes (m is not read or used in the visible code) and edge counter.
int n, m, k, cnt;
// Node weights.
int val[maxn];
// Forward-star adjacency: head[u] is u's first edge id, to/nex per edge id (0 ends a list).
int head[maxn];
int to[maxn<<1];
int nex[maxn<<1];
// sumv[u]: weight of the heaviest downward chain starting at u (filled by dfs).
ll sumv[maxn];
// Union-find lookup: returns the representative of x's set and re-points
// every node on the path from x directly at that representative.
int find(int x){
    int root = x;
    while (p[root] != root) root = p[root];
    // Second pass: compress the path walked above.
    while (p[x] != root) {
        const int next = p[x];
        p[x] = root;
        x = next;
    }
    return root;
}
// Unites the sets containing a and b by hanging a's representative under b's.
void merge(int a,int b){
    const int ra = find(a);
    const int rb = find(b);
    if (ra != rb) p[ra] = rb;
}
// Prepends the directed edge u->v to u's adjacency list.
// Edge ids start at 1, so id 0 marks the end of a list.
void addedge(int u,int v){
    const int id = ++cnt;
    to[id] = v;
    nex[id] = head[u];
    head[u] = id;
}
// Marks the set containing a as taken.
// Returns true on the first call for that set and false on every later call.
bool check(int a){
    int &taken = tag[find(a)];
    if (taken) return false;
    taken = 1;
    return true;
}
void dfs(int u,int fa){
sumv[u]=(ll)val[u];
int flag=0;
ll MAX=0;
for(int v=head[u];v;v=nex[v]){
if(to[v]==fa) continue;
dfs(to[v],u);
if(sumv[to[v]]>MAX) MAX=sumv[to[v]], flag=to[v];
}
if(flag){
sumv[u]+=MAX, merge(u,flag);
for(int v=head[u];v;v=nex[v]){
if(to[v]==fa) continue;
if(to[v]!=flag) Q.push(Node(sumv[to[v]],to[v]));
}
}
}
int main(){
//setIO("game");
scanf("%d%d",&n,&k);
for(int i=1;i<=n;++i) scanf("%d",&val[i]);
for(int i=1;i<=n;++i) p[i]=i;
for(int i=1;i<n;++i){
int a,b;
scanf("%d%d",&a,&b);
addedge(a,b),addedge(b,a);
}
dfs(1,0);
Q.push(Node(sumv[1],1));
int cur=0;
ll fin=0;
while(!Q.empty()&&cur<k){
Node a=Q.top();Q.pop();
if(check(a.u))fin+=a.dist,cur+=1;
}
printf("%lld",fin);
//shutIO();
return 0;
}