HDU 6133 Army Formations
dsu on tree
题意
给你一棵n个节点的二叉树,每个节点有一个提交任务的时间,每个节点总的提交任务的罚时为:提交这个节点和其子树所有的任务,每个任务提交时间的总和为该点的罚时。求每个节点提交完所有任务的最小罚时
思路
树上启发式合并。
考虑有这样一个数据结构, 可以动态往一个多重集里面添加数字, 删除数字, 并查询多重集中元素的sumofsum.
这个非常简单, 用树状数组就好了.
然后我们考虑这样一个启发式过程. 定义一个dfs(root)
, root表示当前子树的根, 用left_son, right_son表示两个儿子, 且左子树的大小比右子树小.
def tree_remove(root):
for x in tree(root):
remove_from_multiset(x)
def tree_add(root):
for x in tree(root):
add_into_multiset(x)
def dfs(root):
dfs(left_son)
tree_remove(left_son)
dfs(right_son)
tree_add(left_son)
add_into_multiset(root)
f[root] = ask_sumofsum()
其中 add_into_multiset
, remove_from_multiset
, ask_sumofsum
为上面说到的数据结构支持的操作: 添加数字, 删除数字, 查询sumofsum.
而tree(root)表示以root为根的子树中的节点. tree_remove(root)和tree_add(root)其实就是遍历以root为根的子树中的节点, 并把他们加到多重集或者从多重集中删除.
复杂度就是, 每个节点被 remove_from_multiset 的次数是 O(logn) 的, 因为每次被 remove 都是因为这个节点处于一个较小的子树中. 这样 remove_from_multiset 和 add_into_multiset 的执行次数都是 O(logn) , 然后考虑树状数组的复杂度, 总复杂度就是 O(nlog2n)
代码
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
#define lson l,mid,rt<<1
#define rson mid+1,r,rt<<1|1
#define M(a,b) memset(a,b,sizeof(a))
using namespace std;
const int MAXN=100007;
const int oo=0x3f3f3f3f;
typedef long long LL;
int l[MAXN], r[MAXN], sz[MAXN];
struct Edge
{
int to, ne;
}e[2*MAXN];
int head[MAXN], edgenum;
void addedge(int u, int v)
{
e[edgenum].to=v, e[edgenum].ne=head[u];head[u]=edgenum++;
e[edgenum].to=u, e[edgenum].ne=head[v];head[v]=edgenum++;
}
void dfs1(int u, int fa)//???????sz
{
sz[u]=1;
for(int i=head[u];~i;i=e[i].ne)
{
int to=e[i].to;
if(to==fa) continue;
dfs1(to, u);
sz[u]+=sz[to];
if(r[u]==0) r[u]=to;
else if(sz[to]>sz[r[u]])
l[u]=r[u], r[u]=to;
else l[u]=to;
}
}
LL sum[MAXN<<2], num[MAXN<<2];
int lowbit(int x) { return x&(-x); }
void update(LL f[], int pos, int val, int tot)
{
while(pos<=tot)
{
f[pos]+=val;
pos+=lowbit(pos);
}
}
LL query(LL f[], int pos)
{
LL ans=0;
while(pos)
{
ans+=f[pos];
pos-=lowbit(pos);
}
return ans;
}
int v[MAXN], ha[MAXN], hav[MAXN];
LL vsum[MAXN];
LL res=0;
void tree_add(int rt, int n)
{
if(rt==0) return;
res+=(query(num, n)-query(num, ha[rt]))*v[rt];
res+=query(sum, ha[rt])+v[rt];
update(num, ha[rt], 1, n);
update(sum, ha[rt], v[rt], n);
tree_add(l[rt], n), tree_add(r[rt], n);
}
void tree_rem(int rt, int n)
{
if(rt==0) return;
res-=(query(num, n)-query(num, ha[rt]))*v[rt];
res-=query(sum, ha[rt]);
update(num, ha[rt], -1, n);
update(sum, ha[rt], -v[rt], n);
tree_rem(l[rt], n), tree_rem(r[rt], n);
}
void dfs2(int n, int u)
{
if(u==0) return;
if(l[u]!=0) dfs2(n, l[u]);
if(l[u]!=0) tree_rem(l[u], n);
if(r[u]!=0) dfs2(n, r[u]);
if(l[u]!=0) tree_add(l[u], n);
res+=(query(num, n)-query(num, ha[u]))*v[u];
res+=query(sum, ha[u])+v[u];
update(num, ha[u], 1, n);
update(sum, ha[u], v[u], n);
vsum[u]=res;
}
int main()
{
int T;scanf("%d", &T);
while(T--)
{
M(l, 0), M(r, 0), M(sz, 0);
M(head, -1);edgenum=1;
M(ha, 0), M(v, 0), M(vsum, 0);
int n;scanf("%d", &n);
for(int i=1;i<=n;i++)
{
scanf("%d", &v[i]);
hav[i]=v[i];
}
sort(hav+1, hav+1+n);
int han=unique(hav+1, hav+n+1)-hav-1;
for(int i=1;i<=n;i++)
ha[i]=lower_bound(hav+1, hav+1+han, v[i])-hav;
for(int i=1;i<n;i++)
{
int u, v;scanf("%d%d", &u, &v);
addedge(u, v);
}
dfs1(1, 0);
M(num, 0), M(sum, 0);res=0;
dfs2(han, 1);
for(int i=1;i<=n;i++) printf("%lld ", vsum[i]);
printf("\n");
}
return 0;
}