F. Strange Memory
题目大意:
有一个根为1的树,每个点都有一个点权ai,求
∑
i
=
1
n
∑
j
=
i
+
1
n
[
a
i
x
o
r
a
j
=
=
a
l
c
a
(
i
,
j
)
]
∗
(
i
x
o
r
j
)
x
o
r
指
异
或
\sum_{i = 1}^{n}\sum_{j = i+1}^{n}[a_i xor a_j == a_{lca(i, j)}]*(i xor j)\\xor指异或
i=1∑nj=i+1∑n[aixoraj==alca(i,j)]∗(ixorj)xor指异或
思路:
1、只有对子树的询问
2、没有修改
很容易想到要用dsu on tree(树上启发式合并)
枚举每一个lca,计算其权值,在统计当前节点的答案时,暴力计算所有轻儿子的贡献,每次保留重儿子的贡献不擦除,擦除轻儿子的贡献
而且我们可以发现 对于 一个lca , 其一个子节点 i, 我们只需要找值为a[lca] ^ a[i]的所有下标统计答案就好了
然后就可以得到一个最初的算法
void dsu(int u, int father, int mark){//mark表示当前贡献是否要擦除
for(int i = head[u]; i; i = e[i].next){//暴力统计轻儿子的贡献
int v = e[i].to;
if(v == father || v == son[u])continue;
dsu(v, u, 0);
}
if(son[u])dsu(son[u], u, 1), vis[son[u]] = 1;//计算重儿子,并标记,当前已经计算过该重儿子,因为不会擦除,所以要标记下来,下次不计算
update(u, 1);//添加当前节点的贡献(这里的贡献不是对答案的贡献!)
for(int i = head[u]; i; i = e[i].next){//枚举以当前节点为lca的答案
int v = e[i].to;
if(v == father || vis[v])continue;
calc(v, u, u);//计算其子节点,以u为lca的答案
for(auto x : q){//将其子节点遍历的所有节点添加
update(x, 1);
}
q.clear();//注意每次要清空,且不能边遍历边添加,防止lca不是u
}
if(son[u])vis[son[u]] = 0;//清除标记
if(!mark)del(u, father);//清除贡献
}
但是calc 和update该怎么写呢
一般来讲,第一时间想到的肯定是,用一个set数组来存值为ai 的所有下标
然后每次统计答案的时候,计算所有ai ^ alca 的所有下标的值和ai带来的对答案的贡献
calc ://这只是伪代码QWQ
set p[val] 存的是值为 val 的所有下标
for(auto x : p[a[i] ^ a[lca]){//当前点i
ans += i ^ x;
}
update就是标记一下就好
但是这样的复杂度真的够吗?
很容易想到这样处理的复杂度是
O
(
k
n
l
o
g
n
)
O(knlogn)
O(knlogn)
这里的k是对于每个节点满足条件的下标个数,在所有点权都相同时,将会直接退化成
O
(
n
2
l
o
g
n
)
O(n^2logn)
O(n2logn)
和暴力(枚举两个点,再求LCA,计算答案)的时间复杂度相同
我们得想办法优化一下!
然后就是重头戏了,将下标拆位!
什么意思呢?
就是将下标拆分成二进制的0和1
假设下标为 5 ,其二进制是101
我们就将下标拆成 1 , 0, 1
然后分别统计每一位所带来的的 对答案的贡献
因为下标最多1e5
二进制位最多16位
所以时间复杂度就变成了
O
(
16
n
l
o
g
n
)
O(16nlogn)
O(16nlogn)
过1e5完全够了
AC代码:
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5 + 100;
const int maxm = 1e6 + 1e5;
struct e_node{
int next;
int to;
}e[maxn << 1];
int head[maxn], a[maxn], cnt[maxm][23][2];
//cnt[val][i][0/1]表示,值为val时下标二进制的所有数第i位为 0/1 的数量
int tot;
void add_e(int u, int v){
e[++tot].to = v;
e[tot].next = head[u];
head[u] = tot;
}
int son[maxn], sz[maxn];
void pre_dfs(int u, int father){
sz[u] = 1;
for(int i = head[u]; i; i = e[i].next){
int v = e[i].to;
if(v == father)continue;
pre_dfs(v, u);
sz[u] += sz[v];
if(sz[son[u]] < sz[v])son[u] = v;
}
}
vector<int > q;
int vis[maxn];
ll ans;
void update(int u, int val){
for(int i = 0; i <= 16; ++i)
cnt[a[u]][i][(u>>i)&1] += val;
}
void calc(int u, int father, int lca){
q.push_back(u);
for(int i = 0; i <= 16; ++i){
ans += cnt[a[u] ^ a[lca]][i][!((u>>i)&1)]*(1<<i);
}
for(int i = head[u]; i; i = e[i].next){
int v = e[i].to;
if(v == father)continue;
calc(v, u, lca);
}
}
void del(int u, int father){//删除贡献
update(u, -1);
for(int i = head[u]; i; i = e[i].next){
int v = e[i].to;
if(v == father || vis[v])continue;
del(v, u);
}
}
void dsu(int u, int father, int mark){//mark表示当前贡献是否要擦除
for(int i = head[u]; i; i = e[i].next){//暴力统计轻儿子的贡献
int v = e[i].to;
if(v == father || v == son[u])continue;
dsu(v, u, 0);
}
if(son[u])dsu(son[u], u, 1), vis[son[u]] = 1;//计算重儿子,并标记,当前已经计算过该重儿子,因为不会擦除,所以要标记下来,下次不计算
update(u, 1);//添加当前节点的贡献(这里的贡献不是对答案的贡献!)
for(int i = head[u]; i; i = e[i].next){//枚举以当前节点为lca的答案
int v = e[i].to;
if(v == father || vis[v])continue;
calc(v, u, u);//计算其子节点,以u为lca的答案
for(auto x : q){//将其子节点遍历的所有节点添加
update(x, 1);
}
q.clear();//注意每次要清空,且不能边遍历边添加,防止lca不是u
}
if(son[u])vis[son[u]] = 0;//清除标记
if(!mark)del(u, father);//清除贡献
}
int main(){
int n;
scanf("%d", &n);
for(int i = 1; i <= n; ++i){
scanf("%d", &a[i]);
}
for(int i = 1; i <= n - 1; ++i){
int u, v;
scanf("%d %d", &u, &v);
add_e(u, v);
add_e(v, u);
}
pre_dfs(1, 0);
dsu(1, 0, 1);
printf("%lld", ans);
}