与E1的不同在于对不同的边进行操作的花费数,有1和2两种。
如果按照E1的类似方法做,貌似挺麻烦的,要用好几个堆来维护。
先来考虑一下最后的答案是怎么产生的。
很明显,是对操作一次花费为1的边进行了a次操作,对操作一次花费为2的边进行了b次操作。而a和b我们都不知道。不过可以肯定的是这a次一定是要对每次操作减去的贡献最多的边进行,b次也是如此。
然后我们可以想到,把进行一次操作花费为1的边和花费为2的边分成两部分。我们枚举一次花费为1的边要操作几次,然后对于一次操作花费为2的边操作几次,二分得到答案。一眼看时间复杂度是:O(n log n)。(因为边的数量为n嘛)
然后发现,由于w[i]<=1e16,所以每一条边最多会操作54次,那么边最多就有540万条,O(54n log 54n) ,也OK。
如果在写E1的时候,不是用堆来处理,而是直接想到把每一条边删除几次后产生的若干条子边都处理出来放一起,排序后直接O(n)往后扫,那么就会大大缩短这两题的代码实现时间。
姬小路秋子还是强!!!
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int T,n,s,aa,bb,sum,l,r,mid,hh,ans;
int d[N],size[N],a[54*N],b[54*N];
struct number{int x,y,w,opt;}num[N];
int cnt,head[N];
struct edge{int next,to,w;}e[N<<1];
inline void add(int u,int v,int w)
{
cnt++;
e[cnt].next=head[u];
e[cnt].to=v;
e[cnt].w=w;
head[u]=cnt;
}
void dfs(int u,int fa)
{
bool jay=true;
for (register int i=head[u]; i; i=e[i].next)
if (e[i].to!=fa)
{
jay=false;
d[e[i].to]=d[u]+1;
dfs(e[i].to,u);
size[u]+=size[e[i].to];
}
if (jay) size[u]=1;
}
signed main(){
scanf("%lld",&T);
while (T--)
{
scanf("%lld%lld",&n,&s);
cnt=0;
for (register int i=1; i<=n; ++i) head[i]=0;
d[1]=0;
for (register int i=1; i<=n; ++i) size[i]=0;
for (register int i=1; i<n; ++i)
{
scanf("%lld%lld%lld%lld",&num[i].x,&num[i].y,&num[i].w,&num[i].opt);
add(num[i].x,num[i].y,num[i].w);
add(num[i].y,num[i].x,num[i].w);
}
dfs(1,0);
aa=bb=0;
sum=0;
for (register int i=1; i<n; ++i)
{
if (d[num[i].x]>d[num[i].y]) swap(num[i].x,num[i].y);
sum+=size[num[i].y]*num[i].w;
if (num[i].opt==1)
{
while (num[i].w)
{
int del=num[i].w-num[i].w/2;
a[++aa]=del*size[num[i].y];
num[i].w/=2;
}
}
else
{
while (num[i].w)
{
int del=num[i].w-num[i].w/2;
b[++bb]=del*size[num[i].y];
num[i].w/=2;
}
}
}
sort(a+1,a+aa+1); reverse(a+1,a+aa+1);
sort(b+1,b+bb+1); reverse(b+1,b+bb+1);
for (register int i=1; i<=aa; ++i) a[i]+=a[i-1];
for (register int i=1; i<=bb; ++i) b[i]+=b[i-1];
ans=1e18;
for (register int i=0; i<=aa; ++i)
{
l=0; r=bb; hh=-1;
while (l<=r)
{
mid=l+r>>1;
if (sum-(a[i]+b[mid])<=s) hh=mid,r=mid-1;
else l=mid+1;
}
if (~hh) ans=min(ans,i+hh*2);
}
printf("%lld\n",ans);
}
return 0;
}