很容易想到对每个点考虑边合并它的子树边得到答案
显然不能用
01
t
r
i
e
01trie
01trie,因为每次查询要加一个
d
i
s
[
f
a
[
u
]
]
dis[fa[u]]
dis[fa[u]] 已经凉了
考虑启发式合并,每次暴力查询一个点在另一个集合中的最大值
按位贪心,假设当前已经选的权值为
x
x
x
如果当前位
i
i
i 为 1,那么如果选出来的数的范围在
[
x
,
x
+
2
i
)
[x,x+2^{i})
[x,x+2i) 就可以有贡献
如果当前位
i
i
i 为 0,那么如果选出来的数的范围在
[
x
+
2
i
,
x
+
2
i
+
1
)
[x+2^i,x+2^{i+1})
[x+2i,x+2i+1) 就可以用贡献
发现需要支持一个区间存不存在一个数,
s
e
t
set
set
l
o
w
e
r
b
o
u
n
d
lowerbound
lowerbound 即可
复杂度
O
(
n
l
o
g
(
n
)
2
l
o
g
(
∑
a
i
)
)
O(nlog(n)^2log(\sum a_i))
O(nlog(n)2log(∑ai))
也可以线段树合并优化一下暴力插的常数,但瓶颈任然在与查询的 3 个
l
o
g
log
log
#include<bits/stdc++.h>
#define cs const
using namespace std;
int read(){
int cnt = 0, f = 1; char ch = 0;
while(!isdigit(ch)){ ch = getchar(); if(ch == '-') f = -1; }
while(isdigit(ch)) cnt = cnt*10 + (ch-'0'), ch = getchar();
return cnt * f;
}
cs int N = 1e5 + 5;
int n, a[N], dis[N], fa[N];
int lg, mx;
vector<int> v[N];
void dfs(int u){
dis[u] += a[u]; mx = max(mx, dis[u]);
for(int i = 0; i < v[u].size(); i++){
int t = v[u][i]; fa[t] = u;
dis[t] += dis[u]; dfs(t);
}
}
int id[N], ans[N];
multiset<int> S[N];
typedef multiset<int>::iterator Int;
int Merge(int x, int y){
if(S[id[x]].size() > S[id[y]].size()) swap(id[x], id[y]);
if(S[id[x]].empty()) return -1;
int ans = 0;
for(Int it = S[id[x]].begin(); it != S[id[x]].end(); it++){
int ret = 0, nx = *it - dis[fa[y]];
for(int i = lg; i >= 0; i--){
int k = (nx >> i) & 1;
int L = ret + dis[fa[y]] + ((k^1) << i);
int R = L + (1 << i);
Int r = S[id[y]].lower_bound(L);
if(r == S[id[y]].end()){ ret += (k << i); continue; }
if(*r >= R){ ret += (k << i); continue; }
else ret += ((k^1) << i);
} ans = max(ans, ret ^ nx);
}
for(Int it = S[id[x]].begin(); it != S[id[x]].end(); it++) S[id[y]].insert(*it);
S[id[x]].clear();
return ans;
}
void calc(int u){
int ret = -1;
for(int i = 0; i < v[u].size(); i++){
int t = v[u][i];
calc(t); ret = max(ret, Merge(t, u));
} ans[u] = ret; S[id[u]].insert(dis[u]);
}
int main(){
n = read();
for(int i = 2; i <= n; i++){
int x = read(); v[x].push_back(i);
}
for(int i = 1; i <= n; i++) a[i] = read(), id[i] = i;
dfs(1);
lg = 1; while((1 << lg + 1) <= mx) ++lg;
calc(1);
for(int i = 1; i <= n; i++) cout << ans[i] << " ";
return 0;
}