题目链接:Army Formations
题意:一棵二叉树,每一个节点有一个信息
ai
,每发送一个信息需要的时间是当前时间
t
加上这个信息的权值
题解:显然,如果我们把一个权值大的放在前面,所有在这个信息后面的信息所需要的时间都会增加,于是我们贪心的发送,即按照权值从小到大发送就是最优方法。那么问题就变成了,我们如何算出子树内每一个节点子树和排序后的前缀和的和。由于我们需要快速合并两棵子树的权值,很容易想到用树状数组即可,同时维护一个current_ans来记录当前的答案。整个操作流程就是:
def tree_remove(root):
for x in tree(root):
remove_from_multiset(x)
def tree_add(root):
for x in tree(root):
add_into_multiset(x)
def dfs(root):
dfs(left_son)
tree_remove(left_son)
dfs(right_son)
tree_add(left_son)
add_into_multiset(root)
f[root] = ask_sumofsum()
以上代码来自官方题解。
为了保证合并的快速,所以我们需要启发式合并,即把小的子树合并到大的子树中,合并过程中每一个元素最多被合并
log
次,而树状数组插入的时间复杂度是
O(log n)
,于是整个算法的时间复杂度就是
O(nlog2n)
。
注意由于要大量调用递归函数,所以一个参数会比两个参数的函数快很多,我被这个地方卡了一上午= =
#include <iostream>
#include <stdio.h>
#include <vector>
#include <algorithm>
using namespace std;
const int N = 100005;
long long c1[N],c2[N],ans[N],t;
int b[N],v[N],id[N],sz[N],ls[N],rs[N],ed;
vector<int>e[N];
void Add(long long* wh,int i,int d){
for(;i<N;i+=i&-i)
wh[i]+=d;
}
long long Sum(long long* wh,int i){
long long res=0;
for(;0<i;i-=i&-i)
res+=wh[i];
return res;
}
int dfs1(int last,int now){
sz[now]=1;
for(int i=0;i<e[now].size();i++)
if(e[now][i]!=last){
if(ls[now]==-1)
ls[now]=e[now][i];
else if(rs[now]==-1)
rs[now]=e[now][i];
sz[now]+=dfs1(now,e[now][i]);
}
if(rs[now]!=-1&&sz[ls[now]]>sz[rs[now]])
swap(ls[now],rs[now]);
return sz[now];
}
void Plu(int x){
t+=(Sum(c1,ed)-Sum(c1,id[x]-1)+1)*v[x]+Sum(c2,id[x]-1);
Add(c1,id[x],1);
Add(c2,id[x],v[x]);
}
void Miu(int x){
t-=(Sum(c1,ed)-Sum(c1,id[x]-1))*v[x]+Sum(c2,id[x]-1);
Add(c1,id[x],-1);
Add(c2,id[x],-v[x]);
}
void Pls(int now){
Plu(now);
if(ls[now]!=-1)
Pls(ls[now]);
if(rs[now]!=-1)
Pls(rs[now]);
}
void Sub(int now){
Miu(now);
if(ls[now]!=-1)
Sub(ls[now]);
if(rs[now]!=-1)
Sub(rs[now]);
}
void dfs2(int now){
if(ls[now]!=-1)
dfs2(ls[now]);
if(rs[now]!=-1){
Sub(ls[now]);
dfs2(rs[now]);
Pls(ls[now]);
}
Plu(now);
ans[now]=(ls[now]==-1&&rs[now]==-1?v[now]:t);
}
int main(){
int T;
scanf("%d",&T);
while(T--){
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++){
e[i].clear();
ls[i]=rs[i]=-1;
scanf("%d",&v[i]);
b[i]=v[i];
}
sort(b+1,b+1+n);
ed=unique(b+1,b+1+n)-(b+1);
for(int i=1;i<=ed;i++)
c1[i]=c2[i]=0;
for(int i=1;i<=n;i++)
id[i]=lower_bound(b+1,b+ed,v[i])-b;
for(int i=1,x,y;i<n;i++){
scanf("%d %d",&x,&y);
e[x].push_back(y);
e[y].push_back(x);
}
dfs1(0,1);
t=0;
dfs2(1);
for(int i=1;i<=n;i++)
printf("%lld ",ans[i]);
puts("");
}
return 0;
}