Codeforces Round #129 (Div. 1)E. Little Elephant and Strings
题意:给出n个字符串,问每个字符串有多少个子串(不同位置,相同的子串视作不同)至少出现在这n个字符串中的k个当中。
解法:这题学到了一个SAM的新技能,对于这多个串,建SAM的时候,不是把它们连在一起,建立SAM,而是先给它们建立Trie树,然后广搜这棵Trie树,对于Trie树上的V节点,在建SAM的时候,它应该接在Trie树上他的父亲节点后面,我们用TtoM[U]表示Trie树上的U节点映射到SAM上的标号。这样建立SAM的优点是,我找任何一个字符串的任何一个前缀,它匹配的的SAM上的节点的代表串必然是这个前缀。我们先记住这个东西,怎么用等会儿看。我们要求的是每个字符串的所有子串至少出现在K个字符串中,那么我们先看看所有的子串中,有哪些子串是出现在了k个字符串中,表达在SAM上就是有哪些节点被K个字符串匹配到过。我们用cnt[u]表示u节点被几个字符串匹配过。我们每次拿出一个字符串,它能给一些节点的cnt[]值贡献1,这些节点,就是这个字符串的每个前缀在sam中的节点到根的链的并集,这个用LCA求就好了。统计完cnt[]之后,看每个节点的cnt值是否大于等于k,是的话,这个节点u上就有val[u]-val[fa[u]]个子串是被k个以上字符串包含的,用add[u]表示这个值。最后,算每个字符串的答案的时候,就是这个字符串的每个前缀映射到SAM上的节点到根的链上的add之和。
代码:
#include<stdio.h>
#include<algorithm>
#include<string.h>
#include<vector>
#include<queue>
#define ll __int64
using namespace std ;
const int maxn = 111111 ;
const int N = maxn << 1;
struct Edge {
int to , next ;
} edge[N] ;
int head[N] , tot , f[N<<1] ;
void new_edge ( int a , int b ) {
edge[tot].to = b ;
edge[tot].next = head[a] ;
head[a] = tot ++ ;
}
char s[maxn] , s1[maxn] ; int l[maxn] , len ;
int TtoM[maxn<<1] ;
struct LCA {
int dp[22][N<<1] ;
int to[N] , tim[N] ;
int tot , n ;
int MIN ( int a , int b ) {
return tim[a] < tim[b] ? a : b ;
}
void init () {
tot = 0 ;
n = 0 ;
}
void dfs ( int u , int fa ) {
tim[u] = ++ tot ;
for ( int i = head[u] ; i != -1 ; i = edge[i].next ) {
int v = edge[i].to ;
if ( v == fa ) continue ;
dfs ( v , u ) ;
dp[0][++n] = u ;
}
dp[0][++n] = u ;
to[u] = n ;
}
void rmq () {
for ( int i = 1 ; i <= 20 ; i ++ ) {
for ( int j = 1 ; j + (1<<i) - 1 <= n ; j ++ ) {
dp[i][j] = MIN ( dp[i-1][j] , dp[i-1][j+(1<<i-1)] ) ;
}
}
}
int query ( int a , int b ) {
a = to[a] , b = to[b] ;
if ( a > b ) swap ( a , b ) ;
int k = b - a + 1 ;
return MIN ( dp[f[k]][a] , dp[f[k]][b-(1<<f[k])+1] ) ;
}
} lca ;
namespace SAM {
int fa[N] , val[N] , c[26][N] ;
int cnt[N] ; int tot , last ;
int ws[N] , wv[N] ;
ll add[N] ;
vector<int> vec[N] ;
void init () ;
void solve ( int , int ) ;
inline int new_node ( int _val ) {
val[++tot] = _val ;
for ( int i = 0 ; i < 26 ; i ++ ) c[i][tot] = 0 ;
cnt[tot] = fa[tot] = add[tot] = 0 ;
vec[tot].clear () ;
return tot ;
}
int ADD ( int k , int p ) {
int i ;
int np = new_node ( val[p] + 1 ) ;
while ( p && !c[k][p] ) c[k][p] = np , p = fa[p] ;
if ( !p ) fa[np] = 1 ;
else {
int q = c[k][p] ;
if ( val[q] == val[p] + 1 ) fa[np] = q ;
else {
int nq = new_node ( val[p] + 1 ) ;
for ( i = 0 ; i < 26 ; i ++ )
c[i][nq] = c[i][q] ;
fa[nq] = fa[q] ;
fa[q] = fa[np] = nq ;
while ( p && c[k][p] == q ) c[k][p] = nq , p = fa[p] ;
}
}
return np ;
}
void SORT () {
for ( int i = 0 ; i < maxn ; i ++ ) wv[i] = 0 ;
for ( int i = 1 ; i <= tot ; i ++ ) wv[val[i]] ++ ;
for ( int i = 1 ; i < maxn ; i ++ ) wv[i] += wv[i-1] ;
for ( int i = 1 ; i <= tot ; i ++ ) ws[wv[val[i]]--] = i ;
}
}
namespace Trie {
int c[26][maxn] , tot ;
int new_node () {
for ( int i = 0 ; i < 26 ; i ++ )
c[i][tot] = 0 ;
return tot ++ ;
}
void init () {
tot = 0 ;
new_node () ;
}
void insert ( int n ) {
for ( int i = 1 ; i <= n ; i ++ ) {
int now = 0 ;
for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
int k = s[j] - 'a' ;
if ( !c[k][now] ) c[k][now] = new_node () ;
now = c[k][now] ;
}
}
}
}
queue<int> Q ;
void SAM::init () {
tot = 0 ;
TtoM[0] = new_node ( 0 ) ;
Q.push ( 0 ) ;
#define v Trie::c[k][u]
while ( !Q.empty () ) {
int u = Q.front () ; Q.pop () ;
for ( int k = 0 ; k < 26 ; k ++ )
if ( v ){
TtoM[v]=ADD(k,TtoM[u]) ;
Q.push ( v ) ;
}
}
}
int cmp ( int a , int b ) {
return lca.tim[a] < lca.tim[b] ;
}
int sta[maxn] ;
void SAM::solve ( int n , int k ) {
SORT () ;
for ( int i = 2 ; i <= tot ; i ++ ) {
new_edge ( fa[i] , i ) ;
}
lca.dfs ( 1 , 0 ) ; lca.rmq () ;
for ( int i = 1 ; i <= n ; i ++ ) {
int u = 0 ;
int top = 0 ;
for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
int k = s[j] - 'a' ;
u = v ;
sta[++top] = TtoM[u];
cnt[TtoM[u]] ++ ;
}
sort ( sta + 1 , sta + top + 1 , cmp ) ;
for ( int j = 2 ; j <= top ; j ++ ) {
int w = lca.query ( sta[j-1] , sta[j] ) ;
cnt[w] -- ;
}
}
for ( int i = tot ; i >= 1 ; i -- ) {
int p = ws[i] ;
cnt[fa[p]] += cnt[p] ;
if ( cnt[p] >= k ) add[p] = val[p] - val[fa[p]] ;
}
for ( int i = 1 ; i <= tot ; i ++ ) {
int u = ws[i] ;
for ( int j = head[u] ; j != -1 ; j = edge[j].next ) {
int to = edge[j].to ;
add[to] += add[u] ;
}
}
for ( int i = 1 ; i <= n ; i ++ ) {
int u = 0 ; ll ans = 0 ;
for ( int j = l[i] ; j < l[i+1] ; j ++ ) {
int k = s[j] - 'a' ;
u = v ;
ans += add[TtoM[u]] ;
}
printf ( "%I64d " , ans ) ;
}
puts ( "" ) ;
}
#undef v
void init () {
tot = 0 ;
memset ( head , -1 , sizeof ( head ) ) ;
lca.init () ;
Trie::init () ;
}
int main () {
f[0] = -1 ;
for ( int i = 1 ; i < maxn << 2 ; i ++ )
f[i] = f[i>>1] + 1 ;
int n , k ;
while ( scanf ( "%d%d" , &n , &k ) != EOF ) {
init () ;
len = 0 ;
for ( int i = 1 ; i <= n ; i ++ ) {
scanf ( "%s" , s1 ) ;
int k = strlen ( s1 ) ;
l[i] = len ;
for ( int j = 0 ; j < k ; j ++ )
s[len++] = s1[j] ;
}
l[n+1] = len ;
Trie::insert (n) ;
SAM::init () ;
SAM::solve ( n , k ) ;
}
return 0 ;
}
/*
3 2
abc
bc
ab
3 2
abc
ac
ab
2 2
abc
bc
1 1
bc
2 2
ab
b
4 4
abab
baba
aaabbbababa
abababababa
2 2
abab
baba
2 2
aba
bab
2 2
ab
ba
*/