题目描述
题解
因为每走一个节点就有可能会减少一部分的蚂蚁,我们可以dfs统计一下从当前节点走到食蚁兽呆的那条边就损失多少倍蚂蚁,也就是经过的各个节点的度数-1的乘积。可以发现如果这个答案
>109
的话是没有意义的,可以直接舍掉。得出了减少多少倍x之后,将m排序,由于m/x的值一定单调,在m上二分找出mi/x=k的区间左端点l和右端点r,即从这个叶子节点出发会有k*(r-l+1)个蚂蚁被吃掉。
我刚开始有一个和这个很相似但是错误的思路:统计出来需要损失x倍的点有多少个,然后对于每一个m二分然后用前缀和求数目。但是这个x的范围是爆炸的,就算
109
以上可以舍掉也没法开数组。所以还是对于每一个x二分m比较科学。
但是还有一点问题是题目中说的是下取整,也就是说,每一次除再下取整不一定等于先全都除完再下取整。那么我们现在就要证明这个东西:
⌊⌊ni⌋j⌋=⌊nij⌋
给出一个不是很靠谱的证明:
令 p=⌊ni⌋ ,再设 ip+k=n,⌊ki⌋=0
那么原式就变成了
⌊pj⌋=⌊nij⌋
⌊nij⌋=⌊pi+kij⌋=⌊piij+kij⌋
显然 kij 不为整数,即 ⌊pj⌋=⌊nij⌋ .
证毕。
代码
#include<algorithm>
#include<iostream>
#include<cstring>
#include<cstdio>
using namespace std;
#define N 1000005
#define LL long long
int n,g,x,y,sx,sy;
int tot,point[N],nxt[N*2],v[N*2],du[N];
LL k,d[N],m[N],cnt,ans;
LL read()
{
LL x=0;char ch=getchar();
while (ch<'0'||ch>'9') ch=getchar();
while (ch>='0'&&ch<='9') x=x*10+ch-'0',ch=getchar();
return x;
}
void addedge(int x,int y)
{
++tot; nxt[tot]=point[x]; point[x]=tot; v[tot]=y;
++tot; nxt[tot]=point[y]; point[y]=tot; v[tot]=x;
}
int findl(LL x)
{
if (!x) return 0;
int l=1,r=g,mid,ans=0;
while (l<=r)
{
mid=(l+r)>>1;
if (m[mid]/x==k) ans=mid,r=mid-1;
else if (m[mid]/x<k) l=mid+1;
else r=mid-1;
}
return ans;
}
int findr(LL x)
{
if (!x) return 0;
int l=1,r=g,mid,ans=0;
while (l<=r)
{
mid=(l+r)>>1;
if (m[mid]/x==k) ans=mid,l=mid+1;
else if (m[mid]/x>k) r=mid-1;
else l=mid+1;
}
return ans;
}
void dfs(int x,int fa)
{
bool flag=false;
for (int i=point[x];i;i=nxt[i])
if (v[i]!=fa)
{
flag=true;
d[v[i]]=d[x]*((LL)du[x]-1);
if (d[v[i]]>1000000000) d[v[i]]=0;
dfs(v[i],x);
}
if (!flag)
{
int L=findl(d[x]);
int R=findr(d[x]);
if (!L||!R) return;
cnt+=(LL)R-(LL)L+1;
}
}
int main()
{
n=(int)read(); g=(int)read(); k=read();
for (int i=1;i<=g;++i) m[i]=read();
sort(m+1,m+g+1);
for (int i=1;i<n;++i)
{
x=(int)read();y=(int)read();
addedge(x,y);
++du[x];++du[y];
if (i==1) sx=x,sy=y;
}
d[sx]=d[sy]=1;
dfs(sx,sy);
dfs(sy,sx);
ans=cnt*k;
printf("%lld\n",ans);
}