题目:https://loj.ac/problem/6198
学习了学习了。
在sam上跑一个合并,将儿子节点的可用数字合并到parent树父亲节点,在01字典树上找异或最大值,再加上当前父亲节点表示的长度(lcp)来更新答案。
很简单明了的一个思路,但是有一个疑问,题目要求的是最长公共前缀,众所周知倒着建sam后树上两个节点的lcp就是其lca,那么这么一路向上合并,有很多时候我们走到的不是lca了,这种时候仿佛有不符合题意的嫌疑。
但是实际上,如果生成异或最大值的两个后缀在同一颗子树里,他们生成的答案在其lca位置是最大的,在之后的非lca位置生成的答案只会更小(parent树上表示的长度单调),不会影响到答案,而如果来自不同子树,当前的节点就是lca了。
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 5;
int n, w[maxn];
char s[maxn];
struct Trie {
int next[maxn * 20][2], tot;
int root[maxn << 1];
void add(int &rt, int x) {
rt = ++tot;
int p = rt;
for (int i = 16, c; i >= 0; --i) {
c = (x >> i) & 1;
if (!next[p][c]) {
next[p][c] = ++tot;
}
p = next[p][c];
}
}
int find(int p, int x) {
int ans = 0;
for (int i = 16, c; i >= 0; --i) {
c = (x >> i) & 1;
if (next[p][c ^ 1]) {
p = next[p][c ^ 1], ans = ans << 1 | 1;
} else {
p = next[p][c], ans = ans << 1;
}
}
return ans;
}
int merge(int x, int y) {
if (!x || !y) {
return x + y;
}
next[x][0] = merge(next[x][0], next[y][0]);
next[x][1] = merge(next[x][1], next[y][1]);
return x;
}
} trie;
struct Sam {
int next[maxn << 1][26];
int link[maxn << 1], step[maxn << 1];
vector<int> v[maxn << 1];
int a[maxn], b[maxn << 1];
int sz, last, root;
void init() {
//如多次建立自动机,加入memset操作
root = sz = last = 1;
}
void add(int c) {
int p = last;
int np = ++sz;
last = np;
step[np] = step[p] + 1;
while (!next[p][c] && p) {
next[p][c] = np;
p = link[p];
}
if (p == 0) {
link[np] = root;
} else {
int q = next[p][c];
if (step[p] + 1 == step[q]) {
link[np] = q;
} else {
int nq = ++sz;
memcpy(next[nq], next[q], sizeof(next[q]));
step[nq] = step[p] + 1;
link[nq] = link[q];
link[q] = link[np] = nq;
while (next[p][c] == q && p) {
next[p][c] = nq;
p = link[p];
}
}
}
}
void build() {
init();
for (int i = n; i > 0; i--) {
add(s[i] - 'a');
v[last].push_back(w[i]);
trie.add(trie.root[last], w[i]);
}
for (int i = 1; i <= sz; i++) {
a[step[i]]++;
}
for (int i = 1; i <= step[last]; i++) {
a[i] += a[i - 1];
}
for (int i = 1; i <= sz; i++) {
b[a[step[i]]--] = i;
}
int ans = 0;
for (int i = sz; i > root; --i) {
int e = b[i];
if (v[link[e]].size() < v[e].size()) {
swap(trie.root[link[e]], trie.root[e]);
swap(v[link[e]], v[e]);
}
int maxw = 0;
for (auto j : v[e]) {
v[link[e]].push_back(j);
maxw = max(maxw, trie.find(trie.root[link[e]], j));
}
trie.root[link[e]] = trie.merge(trie.root[link[e]], trie.root[e]);
ans = max(ans, step[link[e]] + maxw);
}
printf("%d\n", ans);
}
} sam;
int main() {
scanf("%d%s", &n, s + 1);
for (int i = 1; i <= n; ++i) {
scanf("%d", &w[i]);
}
sam.build();
return 0;
}