https://codeforces.com/contest/1399/problem/E2
把每条边的{val,w,sum}求出来放进优先队列,c=1的维护一个,c2=的维护一个,val为这条边操作一次能减少多少值,sum为这条边经过的叶子数量,w为当前边权
那么我们肯定取大的好,d1=q[1].top(),q[1].pop(),d12为消除了d1后,c1=1中能拿出的最大的,所以要把剩下的q[1].top和再次操作d1比较一下取较大,d2=q[2].top()
然后如果d1操作一次直接<=S了,那么直接跑路,否则比较一下d1+d12<d2,那么说明d2消1次比d1,d12消两次还多,那么就消d2,否则消d1
#include<bits/stdc++.h>
#define pb push_back
using namespace std;
typedef long long ll;
const int maxl=3e5+10;
ll n,m,cas,k,cnt,ans;
ll s,tot;
ll son[maxl];
struct ed{ll to,w,c;};
vector<ed> e[maxl];
bool in[maxl];
struct edg
{
ll val;ll w,sum;
bool operator < (const edg &b)const
{
return val<b.val;
}
};
priority_queue<edg> q[3];
inline void dfs(int u,int fa)
{
int v;son[u]=0;bool flag=false;
for(ed ee:e[u])
{
v=ee.to;
if(v==fa) continue;
dfs(v,u);flag=true;
son[u]+=son[v];
q[ee.c].push(edg{1ll*(ee.w-(ee.w/2))*son[v],ee.w,son[v]});
tot+=1ll*ee.w*son[v];
}
if(!flag) son[u]=1;
}
inline void prework()
{
scanf("%lld%lld",&n,&s);
for(int i=1;i<=n;i++)
e[i].clear();
ll u,v,w,c;
for(int i=1;i<=n-1;i++)
{
scanf("%lld%lld%lld%lld",&u,&v,&w,&c);
e[u].push_back(ed{v,w,c});
e[v].push_back(ed{u,w,c});
}
while(!q[1].empty()) q[1].pop();
while(!q[2].empty()) q[2].pop();
tot=0;dfs(1,0);
}
inline void mainwork()
{
edg d11,d12,nd1,d2,nd2;ans=0;
while(tot>s)
{
d11.sum=0;d12.sum=0;d2.sum=0;
if(!q[1].empty())
{
d11=q[1].top();q[1].pop();
d12.val=0;
if(!q[1].empty())
d12=q[1].top();
nd1.w=d11.w/2;nd1.sum=d11.sum;
nd1.val=1ll*(nd1.w-(nd1.w/2))*d11.sum;
if(nd1.val>d12.val)
d12=nd1;
}
if(!q[2].empty())
{
d2=q[2].top(),q[2].pop();
nd2.w=d2.w/2;nd2.sum=d2.sum;
nd2.val=1ll*(nd2.w-(nd2.w/2))*d2.sum;
}
if(tot-d11.val<=s && d11.sum>0)
{
++ans;
return;
}
if(d11.sum!=0 && d2.sum!=0)
{
if(d11.val+d12.val>=d2.val)
{
ans++;tot-=d11.val;
q[1].push(nd1);
q[2].push(d2);
}
else
{
ans+=2;tot-=d2.val;
q[1].push(d11);
q[2].push(nd2);
}
}
else if(d11.sum!=0 && d2.sum==0)
{
ans++;tot-=d11.val;
q[1].push(nd1);
}
else
{
ans+=2;tot-=d2.val;
q[2].push(nd2);
}
}
}
inline void print()
{
printf("%lld\n",ans);
}
int main()
{
int t=1;
scanf("%d",&t);
for(cas=1;cas<=t;cas++)
{
prework();
mainwork();
print();
}
return 0;
}