题目大意
给定一棵 n 个点的树,除 1 外每个点有一只怪兽,打败它需要先消耗 ai 点 HP,再恢复 bi 点 HP。
求从 1 号点出发按照最优策略打败所有怪兽一开始所需的最少 HP。
(2≤n≤105)
(
2
≤
n
≤
10
5
)
解题思路
以 1 为根将树转化成有根树,那么每只怪兽要在父亲怪兽被击败后才能被击败。
考虑简化版问题:忽略父亲的限制,求最优的攻击顺序。
将怪兽分成两类:a < b 的和 a ≥ b 的,前一类打完会加血,后一类打完会扣血,显然最优策略下应该先打第一类再打第二类。
对于 a < b 的怪兽,显然最优策略下应该按照 a 从小到大打。
对于 a ≥ b 的怪兽,考虑两只怪兽 i, j,先打 i 再打 j 的过程中血量会减少到 HP + min(−ai, −ai + bi − aj),因为 a ≥ b,所以这等于 HP − ai − aj + bi。同理先打 j 再打 i 的过程中血量会减少到 HP − ai − aj + bj。
可以发现按照任何顺序都只和 b 有关,最优策略下需要让血量尽可能多,因此要按照 b 从大到小打。
如此可以 O(1) 比较任意两只怪兽应该先打谁,从而得到最优攻击顺序。
考虑原问题,求出忽略父亲限制后最优攻击顺序 p1 , p2 , …, pn。
若 p1 = 1,那么第一步打 p1 一定最优。
若 p1 ̸= 1,那么在打完 p1 的父亲 x 后,紧接着打 p1 一定最优。直接将 p1 和 x 两只怪兽合并,依然用 (a, b) 表示,并将 p1 所有儿子的父亲改为 x 即可消除 p1 的影响。
上面两步将规模为 n 的问题化成了规模为 n − 1 的问题,重复操作直到只剩下一只怪兽即可求出最少所需血量。
需要高效支持修改一个怪兽,删除最优怪兽,以及修改父亲
的操作。
对于最优怪兽,可以用堆维护;对于修改父亲,可以用并查集维护。
时间复杂度
O(nlogn)
O
(
n
l
o
g
n
)
。
by cls
代码
#include <bits/stdc++.h>
using namespace std;
#define mp make_pair
#define x first
#define y second
const int maxn=int(1e5)+111;
struct info {
long long a,b;
bool operator < (const info &rhs) const {
if(a<=b && rhs.a>rhs.b) return true;
if(a>b && rhs.a<rhs.b) return false;
if(a<=b && rhs.a<=rhs.b) return a<rhs.a;
if(a>b && rhs.a>rhs.b) return b>rhs.b;
return false;
}
info operator + (const info &rhs) const {
info R;
R.a=max(a,a-b+rhs.a);
R.b=R.a+(b+rhs.b)-(a+rhs.a);
return R;
}
}p[maxn];
typedef pair<info,int> sta;
int T,n;
vector<int> to[maxn];
int fa[maxn], f[maxn];
set<sta> st;
inline int find(const int &k) {
int p=k;
while(p^f[p]) p=f[p];
return f[k]=p;
}
void dfs(int u,int last) {
fa[u]=last;
for(int i=0;i<(int)to[u].size();++i) if(to[u][i]!=last)
dfs(to[u][i],u);
return;
}
void init() {
register int i;
p[1].a=p[1].b=0;
st.clear();
for(i=1;i<=n;++i) {
to[i].clear();
f[i]=i;
}
return;
}
void work() {
scanf("%d",&n);
init();
register int i,u,v;
for(i=2;i<=n;++i) {
scanf("%I64d%I64d",&p[i].a,&p[i].b);
st.insert(mp(p[i],i));
}
for(i=2;i<=n;++i) {
scanf("%d%d",&u,&v);
to[u].push_back(v), to[v].push_back(u);
}
dfs(1,-1);
while(st.size()) {
u=st.begin()->y; st.erase(st.begin());
v=find(fa[u]);
f[u]=v;
if(v>1) st.erase(mp(p[v],v));
p[v]=p[v]+p[u];
if(v>1) st.insert(mp(p[v],v));
}
printf("%I64d\n",p[1].a);
return;
}
int main() {
#ifndef ONLINE_JUDGE
freopen("H.in","r",stdin);
freopen("output.txt","w",stdout);
#endif
for(scanf("%d",&T);T;T--)
work();
return 0;
}