题意:
给定 $n$ 个字符串,问每个字符串 $s_i$ 有多少个子串(不同位置的子串分别计数)至少在 $k$ 个串中出现过。($n, k, \sum |s_i| \leq 10^5$)
链接:
https://vjudge.net/problem/HYSBZ-3277
解题思路:
如果只有一个串,求有多少子串出现过至少 $k$ 次,那么就对该串建立后缀自动机,然后统计每个结点所代表子串的 $\mathrm{right}$ 集合大小。之后,要么枚举每个结点,统计它所代表子串的答案(拓扑排序,求子树和);要么枚举以位置 $i$ 结尾的子串的答案(共 $n$ 个右端点位置,每个右端点分别对应一条到根的链)。
现在有多个串,那么建立广义后缀自动机。由于要求在至少 $k$ 个串中出现,每个串对一个结点至多提供 $1$ 的贡献,所以先考虑求每个结点代表的子串在多少个串中出现过。把串 $s_i$ 的每个前缀对应的结点看成权值为 $i$ 的点,问题就相当于求每个子树内有多少种不同的权值。这可以用线段树合并得到,也可以按 $dfs$ 序维护树链的并,即 $|s_i|$ 条从根出发的链的并。这两种做法的复杂度都是 $O(n \log n)$。
又或者,直接在自动机上暴力向上跳来统计树链的并(遇到已被本串标记过的结点就停止),这样复杂度是 $O(n\sqrt{n})$ 的。(简要复杂度证明:串 $s_i$ 访问的结点数不超过 $\min\{|s_i|^2,\ 2\sum_j |s_j|\}$,总代价为 $\sum\limits_{i=1}^{n}\min\{|s_i|^2,\ 2\sum_j |s_j|\} = O(n\sqrt{n})$。)
接下来,回到最初提到的两种统计方式。第一种需要对每个结点区分它对哪个串 $s_i$ 有贡献,很难统计;故使用第二种:对每个串 $s_i$,枚举以每个下标结尾的子串的答案,即该前缀对应结点向上到根的一条树链上的答案。这样的贡献不需要区分是哪个串,只需预处理根到每个结点的树链的答案(链上满足条件的结点 $u$ 的 $len_u - len_{par_u}$ 之和)即可。
参考代码:
法一:求树链的并,$O(n \log n)$:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
// lson/rson/gmid, pii, inf and mod are template leftovers not used by this program.
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
// Presumably sized for the ~2 * sum|s_i| states of the generalized SAM
// (sum|s_i| <= 1e5 per the problem statement).
const int maxn = 2e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
// G: children in the suffix-link tree; pi[i]: automaton state of each prefix of string i.
vector<int> G[maxn], pi[maxn];
// s: 1-indexed character buffer for input strings.
char s[maxn];
// Automaton: transitions (lowercase letters), suffix links, longest length per state.
int nxt[maxn][26], par[maxn], len[maxn];
// Heavy-light decomposition of the suffix-link tree (preorder, subtree size,
// parent, depth, heavy-path head, heavy child).
int dfn[maxn], siz[maxn], fa[maxn], dep[maxn], top[maxn], son[maxn];
// dt: difference array indexed by dfn; num[u]: number of distinct strings whose
// prefix chains reach u; sum[u]: root-to-u total of qualifying substring counts;
// ans[i]: answer for string i.
int dt[maxn], num[maxn]; ll sum[maxn], ans[maxn];
// cnt: number of states; last: current insertion state; tim: dfn counter.
int n, k, cnt, last, tim;
// Allocate the next automaton state and record `l` as the length of its
// longest string; returns the new state's index (states are numbered from 1).
// nxt[]/par[] of the new index are not reset here: callers either overwrite
// them or rely on never-used indices still being zero.
int add(int l){
    const int id = ++cnt;
    len[id] = l;
    return id;
}
// Start an empty automaton: restart the state counter and create the root
// (len 0, index 1), which also becomes the current insertion state.
void init(){
    cnt = 0;
    const int root = add(0);
    last = root;
}
// Extend the generalized suffix automaton by character `ch` from state `last`.
// Afterwards `last` is the state whose longest string is the prefix read so far
// (main resets `last` to the root, 1, before each new string).
// `ch` indexes nxt[p][ch - 'a'], so it must be in 'a'..'z'.
void insert(char ch){
int t = ch - 'a', p = last, cur;
// The transition already exists (this prefix occurred in an earlier string):
// reuse or split the target instead of creating an extra state.
if(nxt[p][t]){
int q = nxt[p][t];
// q is exactly one character longer than p, so it already is the right state.
if(len[q] == len[p] + 1) { last = q; return; }
// Otherwise clone q as nq with length len[p] + 1; nq becomes the new `last`.
int nq = add(len[p] + 1); last = nq;
memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
par[nq] = par[q], par[q] = nq;
// Redirect transitions on t that pointed to q along p's suffix-link chain.
while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
return;
}
// Ordinary SAM extension: a new state for the longer prefix.
cur = last = add(len[p] + 1);
// Add the missing transitions on t while walking up suffix links.
while(p && !nxt[p][t]) nxt[p][t] = cur, p = par[p];
// Walked past the root (par of root is 0): cur's suffix link is the root.
if(!p) { par[cur] = 1; return; }
int q = nxt[p][t];
if(len[q] == len[p] + 1) { par[cur] = q; return; }
// Split q: the clone nq takes the shorter length and becomes the suffix link of both q and cur.
int nq = add(len[p] + 1);
memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
par[nq] = par[q], par[q] = par[cur] = nq;
while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
}
// First heavy-light pass over the suffix-link tree (children stored in G).
// Fills the parent fa[], depth dep[] (dep[f] + 1) and subtree size siz[].
// Also fills son[], the child with the strictly largest subtree (first one
// wins ties; 0 for a leaf).
void dfs1(int u, int f){
    fa[u] = f;
    dep[u] = dep[f] + 1;
    siz[u] = 1;
    son[u] = 0;
    for(int child : G[u]){
        dfs1(child, u);
        siz[u] += siz[child];
        if(siz[child] > siz[son[u]]) son[u] = child;
    }
}
// Second heavy-light pass. Assigns preorder numbers dfn[], visiting the heavy
// child first so each heavy path gets consecutive numbers. Records in top[]
// the head of the heavy path containing each node.
void dfs2(int u, int t){
    top[u] = t;
    dfn[u] = ++tim;
    if(son[u] == 0) return;      // leaf: nothing below
    dfs2(son[u], t);             // heavy child continues the current path
    for(int child : G[u]){
        // every other child starts a new heavy path headed by itself
        if(child != fa[u] && child != son[u]) dfs2(child, child);
    }
}
// Lowest common ancestor of u and v in the suffix-link tree using the
// heavy-light decomposition: lift whichever node's path head is deeper
// until both lie on the same heavy path, then return the shallower one.
int lca(int u, int v){
    while(top[u] != top[v]){
        if(dep[top[u]] > dep[top[v]]) u = fa[top[u]];
        else v = fa[top[v]];
    }
    if(dep[u] < dep[v]) return u;
    return v;
}
// Sort comparator: orders automaton states by their DFS preorder number.
int cmp(int x, int y){
    const int dx = dfn[x];
    const int dy = dfn[y];
    return dx < dy;
}
// Root-to-node prefix sums: sum[u] = sum[f] + (len[u] - len[par[u]]) when
// u's substrings occur in at least k strings (num[u] >= k), else sum[f].
// len[u] - len[par[u]] is the number of lengths (len[par[u]], len[u]] that u covers.
void dfs3(int u, int f){
    long long own = 0;
    if(num[u] >= k) own = len[u] - len[par[u]];
    sum[u] = sum[f] + own;
    for(int child : G[u]) dfs3(child, u);
}
// Reads n, k and the n strings and builds the generalized SAM.
// For every state it counts how many distinct strings reach it: the union of
// root chains is formed from dfn-sorted prefix states, with an LCA difference
// array. It then sums, per string, the chain totals of all its prefix states.
int main(){
    ios::sync_with_stdio(0); cin.tie(0);
    cin >> n >> k;
    init();
    // Read into std::string: `istream >> char*` (formerly `cin >> s + 1`) was
    // removed in C++20 and never bounded the write into the fixed buffer.
    std::string str;
    for(int i = 1; i <= n; ++i){
        cin >> str;
        for(char ch : str){
            insert(ch);
            pi[i].pb(last);          // state of the prefix ending at this character
        }
        last = 1;                    // next string starts again from the root
    }
    for(int i = 2; i <= cnt; ++i) G[par[i]].pb(i);
    dfs1(1, 0);
    dfs2(1, 1);
    // Union of root chains for string i: +1 at each prefix state, -1 at the LCA
    // of dfn-adjacent states, so every subtree sum counts string i at most once.
    for(int i = 1; i <= n; ++i){
        if(pi[i].empty()) continue;  // avoid pi[i][0] if input ended early
        sort(pi[i].begin(), pi[i].end(), cmp);
        ++dt[dfn[pi[i][0]]];
        for(int j = 1; j < sz(pi[i]); ++j){
            int u = pi[i][j], v = pi[i][j - 1];
            int lc = lca(u, v);
            ++dt[dfn[u]], --dt[dfn[lc]];
        }
    }
    // Prefix sums over dfn; the subtree of i occupies [dfn[i], dfn[i] + siz[i] - 1].
    for(int i = 1; i <= tim; ++i) dt[i] += dt[i - 1];
    for(int i = 1; i <= cnt; ++i){
        num[i] = dt[dfn[i] + siz[i] - 1] - dt[dfn[i] - 1];
    }
    dfs3(1, 0);
    // Each end position of string i contributes the total of its root chain.
    for(int i = 1; i <= n; ++i){
        for(int v : pi[i]) ans[i] += sum[v];
    }
    for(int i = 1; i <= n; ++i) cout << ans[i] << " ";
    cout << endl;
    return 0;
}
法二:暴力求并,$O(n\sqrt{n})$(没有刻意卡的话跑得飞快,又好写):
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
// lson/rson/gmid, pii, inf and mod are template leftovers not used by this program.
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
// Presumably sized for the ~2 * sum|s_i| states of the generalized SAM
// (sum|s_i| <= 1e5 per the problem statement).
const int maxn = 2e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
// G: children in the suffix-link tree; pi[i]: automaton state of each prefix of string i.
vector<int> G[maxn], pi[maxn];
// s: 1-indexed character buffer for input strings.
char s[maxn];
// Automaton: transitions (lowercase letters), suffix links, longest length per state.
int nxt[maxn][26], par[maxn], len[maxn];
// vis[u]: index of the last string whose climb stamped u; num[u]: number of
// distinct strings whose prefix chains reach u; sum[u]: root-to-u total of
// qualifying substring counts (telescopes to at most len[u], so int suffices).
int vis[maxn], num[maxn], sum[maxn];
// cnt: number of states; last: current insertion state.
int n, k, cnt, last;
// Allocate the next automaton state and record `l` as the length of its
// longest string; returns the new state's index (states are numbered from 1).
// nxt[]/par[] of the new index are not reset here: callers either overwrite
// them or rely on never-used indices still being zero.
int add(int l){
    const int id = ++cnt;
    len[id] = l;
    return id;
}
// Start an empty automaton: restart the state counter and create the root
// (len 0, index 1), which also becomes the current insertion state.
void init(){
    cnt = 0;
    const int root = add(0);
    last = root;
}
// Extend the generalized suffix automaton by character `ch` from state `last`.
// Afterwards `last` is the state whose longest string is the prefix read so far
// (main resets `last` to the root, 1, before each new string).
// `ch` indexes nxt[p][ch - 'a'], so it must be in 'a'..'z'.
void insert(char ch){
int t = ch - 'a', p = last, cur;
// The transition already exists (this prefix occurred in an earlier string):
// reuse or split the target instead of creating an extra state.
if(nxt[p][t]){
int q = nxt[p][t];
// q is exactly one character longer than p, so it already is the right state.
if(len[q] == len[p] + 1) { last = q; return; }
// Otherwise clone q as nq with length len[p] + 1; nq becomes the new `last`.
int nq = add(len[p] + 1); last = nq;
memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
par[nq] = par[q], par[q] = nq;
// Redirect transitions on t that pointed to q along p's suffix-link chain.
while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
return;
}
// Ordinary SAM extension: a new state for the longer prefix.
cur = last = add(len[p] + 1);
// Add the missing transitions on t while walking up suffix links.
while(p && !nxt[p][t]) nxt[p][t] = cur, p = par[p];
// Walked past the root (par of root is 0): cur's suffix link is the root.
if(!p) { par[cur] = 1; return; }
int q = nxt[p][t];
if(len[q] == len[p] + 1) { par[cur] = q; return; }
// Split q: the clone nq takes the shorter length and becomes the suffix link of both q and cur.
int nq = add(len[p] + 1);
memcpy(nxt[nq], nxt[q], sizeof nxt[q]);
par[nq] = par[q], par[q] = par[cur] = nq;
while(p && nxt[p][t] == q) nxt[p][t] = nq, p = par[p];
}
// Top-down pass over the suffix-link tree. For each child v of u,
// sum[v] = sum[u] + (len[v] - len[u]) if v's substrings occur in at least k
// strings, else sum[u]. Because u is v's suffix link, this accumulates the
// qualifying substring count along the root-to-v path; sum of the root is untouched.
void dfs(int u){
    for(int child : G[u]){
        const int gain = num[child] >= k ? len[child] - len[u] : 0;
        sum[child] = sum[u] + gain;
        dfs(child);
    }
}
// Reads n, k and the n strings and builds the generalized SAM.
// Counts for every state how many distinct strings reach it, by brute-force
// climbing of suffix links, then prints for each string the summed chain
// totals of its prefix states.
int main(){
    ios::sync_with_stdio(0); cin.tie(0);
    cin >> n >> k;
    init();
    // Read into std::string: `istream >> char*` (formerly `cin >> s + 1`) was
    // removed in C++20 and never bounded the write into the fixed buffer.
    std::string str;
    for(int i = 1; i <= n; ++i){
        cin >> str;
        for(char ch : str){
            insert(ch);
            pi[i].pb(last);          // state of the prefix ending at this character
        }
        last = 1;                    // next string starts again from the root
    }
    // Brute-force chain union: climb from every prefix state of string i and
    // stop at states already stamped with i, so num[v] counts each string once.
    for(int i = 1; i <= n; ++i){
        for(int v : pi[i]){
            while(v && vis[v] != i) vis[v] = i, ++num[v], v = par[v];
        }
    }
    for(int i = 2; i <= cnt; ++i) G[par[i]].pb(i);
    dfs(1);
    // Each end position of string i contributes the total of its root chain.
    for(int i = 1; i <= n; ++i){
        ll ret = 0;
        for(int v : pi[i]) ret += sum[v];
        cout << ret << " ";
    }
    cout << endl;
    return 0;
}