题目核心问题大意:
给一棵树,问你从K节点,找若干个他的祖先。使得相邻2个元素(一个为另一个最近的祖先)的XOR/AND/OR的值算出来, 然后求和,使得总值最大。
直接做tree dp会T,因为n^2的。
考虑到每个节点的权重数值比较小,所以把一个数值拆分成二进制。(题目中,数值为二进制16位),我们可以拆成高8位,和低8位。
这样,做DP的时候可以记录一个信息,
假设当前DP的节点的权重是W,高8位为A,低8位为B,还有假设这个点的DP值已经算出来了是DP[p]
让H[A][i] 表示, 所有一个节点Q而言,这个Q的低8位为i。他的所有祖先中,如果某个祖先的高8位为A,则这个节点的取值可以是 H[A][i] + (A操作上Q的高8位)
有了这个辅助数组,转移就可以很方便的做到O (2^8)了。
但是,因为在树上做,所以H数组在做到某一个分支中,会被毁掉,所以在回溯的时候给复原就行啦。memcpy还是挺快的
#include <iostream>
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <algorithm>
#include <cstring>
#include <vector>
#include <map>
#include <string>
using namespace std;
typedef long long LL;
const int maxn = 1000000 + 100;
char opt[10];
int n;
vector<int>g[maxn];
LL w[maxn], f[maxn];
const LL inf = -(1LL<<60);
const LL mod = 1e9+7;
void init()
{
scanf("%d%s", &n, &opt);
for (int i = 1;i<=n;++i)
{
scanf("%lld", &w[i]);
g[i].clear();
}
for (int i = 2; i<= n; ++ i)
{
int k;
scanf("%d", &k);
g[k].push_back(i);
}
}
LL h[1<<8][1<<8];
const LL low8bit = (1<<8)-1;//低8位
LL backup[200000][1<<8];
LL cal(LL a, LL b)
{
if (opt[0] == 'A') return a & b;
if (opt[0] == 'X') return a ^ b;
return a | b;
}
void dfs(int now)
{
//backup
memcpy(backup[now], h[w[now]>>8], sizeof(backup[now]));
f[now]=0;
for (int i = 0; i <= low8bit; ++ i)
f[now] = max(h[i][w[now] & low8bit] + (cal(w[now]>>8, i)<<8), f[now]);
for (int i = 0; i <= low8bit; ++ i)
h[w[now]>>8][i] = max(h[w[now]>>8][i], f[now] + cal((w[now] & low8bit), i));
for (auto i : g[now]) dfs(i);
memcpy(h[w[now]>>8], backup[now], sizeof(backup[now]));
}
void doit()
{
for (int i = 0; i <= low8bit; ++ i)
for (int j = 0; j <= low8bit; ++ j)
h[i][j] = inf;
LL ans = 0;
dfs(1);
for (int i = 1; i <= n; ++ i)
ans = (ans + (f[i] +w[i]) * i) % mod;
printf("%lld\n", ans);
}
int main()
{
int T;
scanf("%d", &T);
while (T--)
{
init();
doit();
}
return 0;
}