链接:https://www.nowcoder.com/acm/contest/200/F
时间限制:C/C++ 2秒,其他语言4秒
空间限制:C/C++ 204800K,其他语言409600K
64bit IO Format: %lld
题目描述
事实证明,小可爱是最可爱的嘤! ——佚名
现在小可爱在颓游戏,但是他遇到了一个问题:
小可爱率领的部队现在面对的是敌军在这一地区的驻军,敌国战争机器的运作很大程度上依赖指挥,所以敌军内部是严明分级的,就是说,全部敌军可以看作一棵树,每只敌军部队(树上每个节点)有其战斗力。你可以对任意敌军部队发动进攻,小可爱的部队有战斗力p,意味着他的每次进攻将使得被进攻的这支部队的战斗力减少p,对上级指挥系统的打击同时会影响其下级部队。具体来说,当他对点i发动进攻,部队i的战力减少p的同时,对于其子树内点j,部队j的战力减少
M
a
x
(
0
,
p
−
d
i
s
(
i
,
j
)
2
)
Max(0,p−dis(i,j)^2)
Max(0,p−dis(i,j)2)(dis(i,j)表示点i,j间简单路径的长度)。如果某支部队战力小于0,那么这支部队就被消灭了,一支部队被消灭不会改变敌军编制(即这棵树的结构不会改变)。
小可爱想知道,你的部队最少发动几次进攻,才能全歼敌军
由于小可爱还要爆手速发展自己实力,所以把这个问题交给了你。
输入描述:
第一行两个正整数n,p,分别表示德军部队数目和你部战斗力
第二行n个正整数,表示德军各部战斗力 m i m_i mi
第三行到第n+1行,每行两个正整数i,j,表示i,j两支部队存在从属关系(i为j的上级)
输出描述:
输出一个整数,表示最少进攻次数
示例1
输入
7 3
1 1 3 7 5 3 3
1 2
2 3
1 4
2 5
4 6
1 7
输出
8
说明
对一号、七号部队各发动一次进攻,对三号、四号、五号部队各发动两次进攻
备注:
ps:
1,敌军以一号部队为司令部,即这棵树的根节点为1
2,你不能对已经被消灭的部队发动进攻
提示:
1,本题数据范围差距较大,分层明显
2,本题输入数据较大,请不要使用cin
对于100%的数据,n<=1000000;p,mi<=1e9
思路:很明显,贪心从根节点开始进攻,然后再向儿子进攻是最优的。
那么对于节点k,如何计算其祖先节点对k点的影响呢?
设d[k]表示k在树中的深度,则
p
−
d
i
s
(
i
,
j
)
2
=
p
−
(
d
[
i
]
−
d
[
j
]
)
2
p-dis(i,j)^2=p-(d[i]-d[j])^2
p−dis(i,j)2=p−(d[i]−d[j])2
p
−
d
i
s
(
i
,
j
)
2
=
p
−
(
d
[
i
]
2
−
2
∗
d
[
i
]
∗
d
[
j
]
+
d
[
j
]
2
)
p-dis(i,j)^2=p-(d[i]^2-2*d[i]*d[j]+d[j]^2)
p−dis(i,j)2=p−(d[i]2−2∗d[i]∗d[j]+d[j]2)
因为 p p p和 d [ i ] d[i] d[i]已知,所以我们只需要记录 ∑ d [ j ] \sum d[j] ∑d[j]和 ∑ d [ j ] 2 \sum d[j]^2 ∑d[j]2以及相应的个数,即可求出祖先对当前节点的影响。
#include<bits/stdc++.h>
using namespace std;
const int MAX=1e6+10;
const int MOD=1e9+7;
const double PI=acos(-1.0);
typedef long long ll;
ll a[MAX],d[MAX];
ll p;
vector<int>e[MAX];
int nex[MAX];
ll ans=0;
void dfs(int k,ll sum,ll sqsum,ll cnt,ll dep,int now)
{
d[k]=dep;
ll num;
while(nex[now]!=k&&p<=(d[nex[now]]-d[k])*(d[nex[now]]-d[k]))
{
now=nex[now];
num=a[now]/p+1;
sum-=num*d[now];
sqsum-=num*d[now]*d[now];
cnt-=num;
}
a[k]-=cnt*p-(cnt*d[k]*d[k]-2*sum*d[k]+sqsum);
if(a[k]<0)a[k]=-p;
num=a[k]/p+1;
for(int i=0;i<e[k].size();i++)
{
nex[k]=e[k][i];
dfs(e[k][i],sum+num*d[k],sqsum+num*d[k]*d[k],cnt+num,dep+1,now);
}
ans+=num;
}
int main()
{
int n;
cin>>n>>p;
for(int i=1;i<=n;i++)scanf("%lld",&a[i]);
for(int i=1;i<n;i++)
{
int x,y;
scanf("%d%d",&x,&y);
e[x].push_back(y);
}
nex[0]=1;
dfs(1,0,0,0,0,0);
cout<<ans<<endl;
return 0;
}