大意就是给一个树,给很多个问询,选择两个点之间的若干个点,在保证选择一个点后与其直接相连的点不能选,每次询问求总和最大值。必须在线。
最开始想的是树剖,但是做完签到后一看K题只有几个AC就果断run了。
后来事实也证明不是树剖,应该用点分治+猫树。(如果没有dp部分的话我感觉我应该能写出来,我dp太菜了orz)
点分治可以处理掉树上的数据问题。
猫树是线段树的一个变种,以每个区间的中点,分别向左右两边进行DP储存好数据。算出当问询区间落在中点两段时所在的层数,就能实现O(1)查询
目前的水平还写不出来这道题(DP部分还没搞定是怎么操作的),所以DP部分只好看着标程用自己的码风写了一遍QWQ(不得不说标程的代码简直如同加密文件)
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10;
const int LOGN = 22;
struct Rand {
unsigned int n, seed;
Rand(unsigned int n, unsigned int seed): n(n), seed(seed) {}
int get(long long lastans) {
seed ^= seed << 13;
seed ^= seed >> 17;
seed ^= seed << 5;
return (seed ^ lastans) % n + 1;
}
};
struct DP {
ll v[2][2];
ll get()const {
return max(max(v[0][0], v[0][1]), max(v[1][0], v[1][1])); //四个选择的最大值
}
};
using merge_t = DP;
DP reverse(DP a) {
swap(a.v[0][1], a.v[1][0]);
return a;
}
int merge_count;
DP merge(merge_t a, merge_t b) {
DP res;
merge_count++;
for (int i = 0; i < 2; i++) {
for (int j = 0; j < 2; j++) {
res.v[i][j] = max(a.v[i][1] + b.v[0][j], max(a.v[i][0] + b.v[1][j], a.v[i][0] + b.v[0][j]));
}
}
return res;
}
int dfn[N * 2];
int dep[N * 2];
int LOG2[N * 2];//以2为底的log函数,比log(n)/log(2)更高效
int cnt;
int min_id(int a, int b) {
return dep[a] < dep[b] ? a : b;
}
int st[LOGN][N * 2]; //猫树用的st表
void push(int id) {
st[0][cnt] = id;
cnt++;
}
void ini() {
LOG2[0] = -1;
for (int i = 2; i < N * 2; i++) { //初始化LOG函数
LOG2[i] = LOG2[i >> 1] + 1;
}
for (int k = 1; k < LOGN; k++) { //ST表操作
for (int i = 0; i + (1 << k) < cnt; i++) {
st[k][i] = min_id(st[k - 1][i], st[k - 1][i + (1 << (k - 1))]);
}
}
}
ll a[N];//储存初始数据
vector<int>tr[N];//邻接表
DP from_root[LOGN][N];//不包含端点的最大值
DP to_root[LOGN][N];//包含端点的最大值
int size1[N];//以i点为源点的子树的大小
bool erased[N];//判断有无被删除
int tot_size;
int root_max_son;//重心最大子树的点数
int root;//重心
// 当前节点,父亲节点
int get_size(int now, int from = 0) { //获取一个子树的大小
size1[now] = 1;
for (int i = 0; i < tr[now].size(); i++) {
int to = tr[now][i];
if (!erased[to] && to != from) { //如果to点没有被删,且下一条边不是父系节点
size1[now] += get_size(to, now);
}
}
return size1[now];
}
void get_root(int now, int from = 0) {//获取重心
int max_son = tot_size - 1 - size1[now]; //
for (int i = 0; i < tr[now].size(); i++) {
int to = tr[now][i];
if (!erased[to] && to != from) {
get_root(to, now);
max_son = max(max_son, size1[to]); //寻找节点数最大的连通块
}
}
if (max_son < root_max_son) {
root = now;
root_max_son = max_son;
}
}
void build_dp(int now, int depth, int from) {
if (from == root) {
from_root[depth][now] = (DP) {
0, 0, 0, a[now]
};//如果是源点
} else { //如果不是源点,就DP
from_root[depth][now] = merge(from_root[depth][from], (DP) {
0, 0, 0, a[now]
});
}
to_root[depth][now] = merge((DP) {
0, 0, 0, a[now]
}, to_root[depth][from]);
for (int to : tr[now])
if (!erased[to] && to != from)
build_dp(to, depth, now);
}
//点分治
void divide(int now, int depth = 0) { //当前节点,当前深度(不传入就当作0)
tot_size = root_max_son = get_size(now);
get_root(now);//获取当前点的重心
now = root; //更新当前重心
dep[now] = depth;
dfn[now] = cnt;
push(now);
erased[now] = 1; //删除重心
from_root[depth][now] = {-LLONG_MAX, -LLONG_MAX, -LLONG_MAX, -LLONG_MAX};
to_root[depth][now] = {0, 0, 0, a[now]};
for (int to : tr[now]) {
if (!erased[to]) {
build_dp(to, depth, now);
}
}
for (int to : tr[now]) {
if (!erased[to]) {
divide(to, depth + 1);//分治
push(now);
}
}
}
ll query(int l, int r) {
if (l == r)
return a[l];
if (dfn[l] > dfn[r]) {
swap(l, r);
}
int lca = 1;
int k = LOG2[dfn[r] - dfn[l] + 1];
lca = min_id(st[k][dfn[l]], st[k][dfn[r] + 1 - (1 << k)]);
// 不包含端点 包含端点 获取最大
return max(merge(to_root[dep[lca]][l], from_root[dep[lca]][r]).get(), 0ll);
}
void solve() {
int n, m, seed;
cin >> n;
cin >> m;
cin >> seed;
for (int i = 1; i <= n; i++)
cin >> a[i];
for (int i = 1; i < n; i++) {
int u = i + 1;
int v;
cin >> v;
tr[u].push_back(v);
tr[v].push_back(u);
}
divide(1);//点分治
ini();//初始化LOG2数组和ST表
ll last_ans = 0; //用于更新
ll ans = 0;
constexpr int P = 998244353;
Rand rand(n, seed);
for (int i = 0; i < m; i++) {
int u = rand.get(last_ans);
int v = rand.get(last_ans);
int x = rand.get(last_ans);
last_ans = query(u, v);
ans = (ans + last_ans * x) % P;
}
cout << ans << endl;
}
int main() {
solve();
return 0;
}