题目大意:定义 p ( x , k ) : p(x,k): p(x,k):x 为根结点,深度小于等于k的所有结点个数。再定义一个集合S的运算 f ( S ) : f(S): f(S):集合内元素两两异或值的平方和,例如 S = {1,1,2,3},则 f(S) = ( 1 ⨁ 1 ) 2 (1 \bigoplus1)^2 (1⨁1)2 + ( 1 ⨁ 2 ) 2 (1 \bigoplus2)^2 (1⨁2)2 + ( 1 ⨁ 3 ) 2 (1 \bigoplus3)^2 (1⨁3)2 + ( 1 ⨁ 2 ) 2 (1 \bigoplus2)^2 (1⨁2)2 + ( 1 ⨁ 3 ) 2 (1 \bigoplus3)^2 (1⨁3)2 + ( 2 ⨁ 3 ) 2 (2 \bigoplus3)^2 (2⨁3)2
输入 n,k,给一棵n个结点的有根树,1为根结点,对树上所有结点 x,输出所有的 f ( p ( x , k ) ) f(p(x,k)) f(p(x,k))。
根据式子很容易想到按位的贡献来做,单独考虑每一位的贡献,如果两个数异或起来有三个1,例如a,b,c位为 1,那么这一对数的贡献是 ( 2 a + 2 b + 2 c ) 2 (2^a + 2^b + 2^c)^2 (2a+2b+2c)2,简记为 ( a + b + c ) 2 (a + b + c)^2 (a+b+c)2。
拆开后可以得到 a 2 + b 2 + c 2 + 2 a b + 2 a c + 2 b c a^2 + b^2 + c^2 + 2ab + 2ac + 2bc a2+b2+c2+2ab+2ac+2bc
一个做法是:
第一次先计算 a 2 + b 2 + c 2 a^2 + b^2 + c^2 a2+b2+c2,即单个位为1的贡献,这很容易计算,只要集合内考虑某一位为1的数字个数乘上这一位为0的数字个数再乘上这一位贡献的平方
对于 2 a b + 2 b c + 2 a c 2ab + 2bc + 2ac 2ab+2bc+2ac,需要考虑两个位,经过第一步的计算之后,第二步计算需要补加贡献:枚举两个位x,y,这两个数异或起来这两个位同时为1,那么必然漏算了 2 x y 2xy 2xy 这个贡献,统计有多少对数字或起来这两异同时为 1,然后把漏掉的贡献加上。
对于在有根树上统计某一位或某两位为1并且与深度有关的结点数,显然可以dsu on tree 或长链剖分,由于需要枚举二进制位pair,用dsu on tree 复杂度为达到 n log 3 ( n ) n \log^3(n) nlog3(n),即使4s也可能会T飞
上长链剖分,需要维护一个形如 dp[i][j]:表示 i 为根结点,深度 ≤ \leq ≤ j 且某一位或某两位满足要求的结点个数,直接维护这个前缀和或统计这个前缀和复杂度都会达到 n 2 n^2 n2,因为必然会枚举到重儿子的链长,而长链剖分的复杂度保证就是不枚举重儿子的链长。
做法是改为维护后缀和,长链剖分可以直接维护并转移,需要的数据减一下即可。
然后再上一个更加弱智的写法,常数大到差点过不去
代码:
#include<bits/stdc++.h>
using namespace std;
const int maxn = 2e5 + 10;
vector<int> g[maxn];
typedef unsigned long long ll;
ll tmp[maxn],*id1,*dp[maxn];
ll tmp2[maxn],*id2,*tp[maxn];
ll val[maxn];
int len[maxn],son[maxn],n,k;
ll ans[maxn];
void prework(int u,int fa) {
len[u] = 0;son[u] = 0;
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa) continue;
prework(it,u);
if(son[u] == 0 || len[son[u]] < len[it])
son[u] = it;
}
len[u] = len[son[u]] + 1;
}
void dfs(int u,int fa,int p) {
dp[u][0] = ((val[u] >> p) & 1);
tp[u][0] = !dp[u][0];
if(son[u]) {
tp[son[u]] = tp[u] + 1;
dp[son[u]] = dp[u] + 1;
dfs(son[u],u,p);
}
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa || it == son[u]) continue;
dp[it] = id1; id1 += len[it];
tp[it] = id2; id2 += len[it];
dfs(it,u,p);
for(int i = 0; i < len[it]; i++) {
tp[u][i + 1] += tp[it][i];
dp[u][i + 1] += dp[it][i];
}
}
if(len[u] - 1 >= 1) {
dp[u][0] += dp[u][1];
tp[u][0] += tp[u][1];
}
ll a = dp[u][0],b = tp[u][0];
if(k + 1 <= len[u] - 1) {
a -= dp[u][k + 1];
b -= tp[u][k + 1];
}
ans[u] += a * b * (1llu << p) * (1llu << p);
}
void dfs2(int u,int fa,int p,int q) {
dp[u][0] = ((val[u] >> p) & 1) && (val[u] >> q & 1);
tp[u][0] = !((val[u] >> p) & 1) && !(val[u] >> q & 1);
if(son[u]) {
tp[son[u]] = tp[u] + 1;
dp[son[u]] = dp[u] + 1;
dfs2(son[u],u,p,q);
}
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa || it == son[u]) continue;
dp[it] = id1; id1 += len[it];
tp[it] = id2; id2 += len[it];
dfs2(it,u,p,q);
for(int i = 0; i < len[it]; i++) {
tp[u][i + 1] += tp[it][i];
dp[u][i + 1] += dp[it][i];
}
}
if(len[u] - 1 >= 1) {
dp[u][0] += dp[u][1];
tp[u][0] += tp[u][1];
}
ll a = dp[u][0],b = tp[u][0];
if(k + 1 <= len[u] - 1) {
a -= dp[u][k + 1];
b -= tp[u][k + 1];
}
ans[u] += 2 * a * b * (1llu << p) * (1llu << q);
}
void dfs3(int u,int fa,int p,int q) {
dp[u][0] = (((val[u] >> p) & 1) == 1) && ((val[u] >> q & 1) == 0);
tp[u][0] = (((val[u] >> p) & 1) == 0) && ((val[u] >> q & 1) == 1);
if(son[u]) {
tp[son[u]] = tp[u] + 1;
dp[son[u]] = dp[u] + 1;
dfs3(son[u],u,p,q);
}
for(int i = 0; i < g[u].size(); i++) {
int it = g[u][i];
if(it == fa || it == son[u]) continue;
dp[it] = id1; id1 += len[it];
tp[it] = id2; id2 += len[it];
dfs3(it,u,p,q);
for(int i = 0; i < len[it]; i++) {
tp[u][i + 1] += tp[it][i];
dp[u][i + 1] += dp[it][i];
}
}
if(len[u] - 1 >= 1) {
dp[u][0] += dp[u][1];
tp[u][0] += tp[u][1];
}
ll a = dp[u][0],b = tp[u][0];
if(k + 1 <= len[u] - 1) {
a -= dp[u][k + 1];
b -= tp[u][k + 1];
}
ans[u] += 2 * a * b * (1llu << p) * (1llu << q);
}
int main() {
scanf("%d%d",&n,&k);
for(int i = 1; i <= n; i++) {
scanf("%llu",&val[i]);
}
for(int i = 2,f; i <= n; i++) {
scanf("%d",&f);
g[f].push_back(i);
g[i].push_back(f);
}
prework(1,0);
for(int j = 0; j <= 29; j++) {
id1 = tmp,dp[1] = id1,id1 += len[1];
id2 = tmp2,tp[1] = id2,id2 += len[1];
dfs(1,0,j);
}
for(int i = 0; i <= 29; i++) {
for(int j = i + 1; j <= 29; j++) {
id1 = tmp,dp[1] = id1,id1 += len[1];
id2 = tmp2,tp[1] = id2,id2 += len[1];
dfs2(1,0,i,j);
}
}
for(int i = 0; i <= 29; i++) {
for(int j = i + 1; j <= 29; j++) {
id1 = tmp,dp[1] = id1,id1 += len[1];
id2 = tmp2,tp[1] = id2,id2 += len[1];
dfs3(1,0,i,j);
}
}
for(int i = 1; i <= n; i++) {
printf("%llu\n",ans[i]);
}
return 0;
}