思路特别简单:
先求出每条边的贡献(边权 × 经过该边的根到叶路径数,即该边下方子树中的叶子数),把每条边放入大顶堆;每次取出"减半收益"最大的边,从 sum 中减去该收益、把边权除以 2(下取整)后再 push 回堆,直到不再满足 sum>S 为止,操作次数即为答案。
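注意大顶堆里的比较方法(这里WA了三次):应按"把该边权减半后总和能减少多少"即 val*num-(val/2)*num 比较,而不是按当前贡献 val*num 比较。例如 (val=10,num=1) 的贡献为 10、减半收益为 5,而 (val=3,num=3) 的贡献为 9、减半收益却为 6。比较函数如下: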
// Heap entry for one tree edge: its current weight `val` and `num`, the number
// of root-to-leaf paths that use the edge (the leaf count below it).
struct node2
{
    long long val;
    long long num;
    node2(long long x, long long y) : val(x), num(y) {}

    // How much the total path sum drops if this edge's weight is halved once
    // (integer division rounds down): val*num - (val/2)*num.
    long long gain() const { return val * num - (val / 2) * num; }

    // Max-heap order by gain, NOT by the current contribution val*num:
    // (val=10,num=1) contributes 10 but gains only 5, while (val=3,num=3)
    // contributes 9 yet gains 6. The parameter is named `other` so it does not
    // shadow the type name node2.
    bool operator<(const node2& other) const { return gain() < other.gain(); }
};
std::priority_queue<node2> que;
AC代码
#include <stdio.h>
#include <string.h>
#include <math.h>
#include <algorithm>
#include <iostream>
#include <istream>
#include <map>
#include <queue>
#include <set>
#include <utility>
#include <vector>
#define ll long long
using namespace std;
// "Infinity" sentinel 0x3f3f3f3f (= 1061109567); not referenced in the code shown here.
#define inf 0x3f3f3f3f
// Fast reader for one signed decimal integer from `in` (stdin by default, so
// existing `read()` calls behave as before).
// Characters before the first digit are skipped; a '-' seen while skipping makes
// the result negative. The first non-digit after the number is consumed.
// Returns 0 if the input ends before any digit is found, instead of spinning
// forever on EOF.
inline long long read(FILE* in = stdin) {
    long long k = 0, f = 1;
    int ch = getc(in);  // int, not char: EOF must stay distinguishable from a byte
    while (ch < '0' || ch > '9') {
        if (ch == EOF) return 0;
        if (ch == '-') f = -1;
        ch = getc(in);
    }
    while (ch >= '0' && ch <= '9') {
        k = k * 10 + (ch - '0');
        ch = getc(in);
    }
    return k * f;
}
// Modulus constant; not referenced in the code shown here.
#define mod 998244353

// One adjacency-list entry: the neighbouring vertex and the weight of the connecting edge.
struct node
{
    ll to;  // neighbouring vertex
    ll w;   // edge weight
};

vector<node> vec[100005];  // adjacency lists, filled for vertices 1..n in main
ll num[100005];            // leaf counts: 1 for degree-1 vertices, then child counts are added in by dfs
ll du[100005];             // vertex degrees
// Heap entry for one tree edge: its current weight `val` and `num`, the number
// of root-to-leaf paths that use the edge (the leaf count below it, as pushed by dfs).
struct node2
{
    long long val;
    long long num;
    node2(long long x, long long y) : val(x), num(y) {}

    // How much the total path sum drops if this edge's weight is halved once
    // (integer division rounds down): val*num - (val/2)*num.
    long long gain() const { return val * num - (val / 2) * num; }

    // Max-heap order by gain, NOT by the current contribution val*num:
    // (val=10,num=1) contributes 10 but gains only 5, while (val=3,num=3)
    // contributes 9 yet gains 6. The parameter is named `other` so it does not
    // shadow the type name node2.
    bool operator<(const node2& other) const { return gain() < other.gain(); }
};
std::priority_queue<node2> que;  // max-heap of edges keyed by halving gain
bool vis[100005];  // visited flags for dfs, reset per test case in main
ll sum=0;          // current sum of all root-to-leaf path weights

// Post-order traversal of the tree starting at vertex x (main calls dfs(1)).
// When a child's subtree is finished, for the tree edge parent -> child with
// weight w it does:
//   num[parent] += num[child];
//   sum += num[child] * w;           // this edge's contribution to the total
//   que.push(node2(w, num[child]));  // candidate for halving
// Uses an explicit stack instead of recursion so a path-shaped tree with ~1e5
// vertices cannot overflow the call stack. Neighbours are scanned in adjacency
// order, so visit order and push order match the recursive formulation.
void dfs(ll x)
{
    // Each frame: (vertex, index of the next adjacency entry to examine).
    vector<pair<ll, size_t>> stk;
    vis[x]=1;
    stk.emplace_back(x, 0);
    while(!stk.empty())
    {
        const ll u=stk.back().first;
        const size_t i=stk.back().second;
        if(i<vec[u].size())
        {
            ++stk.back().second;  // advance before a child frame is pushed on top
            const ll to=vec[u][i].to;
            if(!vis[to])
            {
                vis[to]=1;
                stk.emplace_back(to, 0);
            }
            continue;
        }
        // Every neighbour of u is handled: u's subtree is complete, fold it into its parent.
        stk.pop_back();
        if(stk.empty()) break;
        const ll p=stk.back().first;
        const node& e=vec[p][stk.back().second-1];  // the edge p -> u we descended through
        num[p]+=num[u];
        sum+=num[u]*e.w;
        que.push(node2(e.w,num[u]));
    }
}
int main()
{
ll t=read();
while(t--)
{
ll n=read(),S=read();
sum=0;
while(!que.empty())que.pop();
for(int i=1;i<=n;i++)du[i]=0,vec[i].clear();
for(int i=0;i<n-1;i++)
{
ll u=read(),v=read(),w=read();
node temp;
temp.to=v;
temp.w=w;
vec[u].push_back(temp);
node temp2;
temp2.to=u;
temp2.w=w;
vec[v].push_back(temp2);
du[u]++;
du[v]++;
}
for(int i=1;i<=n;i++)
{
vis[i]=0;
if(du[i]==1)num[i]=1;
else num[i]=0;
}
dfs(1);
ll ans=0;
while(sum>S)
{
ll tval=que.top().val;
ll tnum=que.top().num;
que.pop();
sum-=(tval*tnum-(tval/2)*tnum);
ans++;
que.push(node2(tval/2,tnum));
}
cout<<ans<<'\n';
}
}