后缀自动机入门题集

学了好几天后缀自动机,总算是真正搞懂了,才敢来发博客。

后缀自动机是啥以及怎么构造就不说了,毕竟有很多博客比我讲的好多了。
还是按照国际惯例,推荐几发:
hiho一下 127~132周
后缀自动机入门
史上最通俗的后缀自动机详解

先谈谈我对parent tree的理解:
首先由后缀链接构成的树就叫做parent tree(也叫后缀链接树)。
但是我们存的是反边,所以想dfs的时候就很不舒服,有两种解决办法:
1.倒着存正边(没见人用过)。
2.根据SAM的性质,子节点所代表的最长的字符串的长度一定大于父亲节点的,所以根据len的大小排序,然后从大往小更新即可。排序的时候用桶排序,可以省去一个log。
后缀链接的父亲节点的endpos完全包含子节点的endpos,且父亲节点所代表的字符串是子节点的后缀。
关于每一个节点存的endpos的数量:父亲节点的不一定恰好比子节点多1,而是等于所有子节点的endpos数量之和;如果父亲节点本身对应原串的一个前缀(即不是复制出来的节点),还要再加上1。

接下来是入门题:
luogu P3804 【模板】后缀自动机
确实是模板,我们建完后缀自动机后,求出每一个节点的endpos的数量,做法就是插入的同时标记一下这个节点是一个endpos。然后递归求一遍子树的endpos之和就是这个节点的endpos个数咯。之后如果siz[i] > 1,就用siz[i] * len[i]更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble shared by the code below.
using namespace std;
// `enter` prints a newline via puts, `space` prints a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Fills a with the byte value x; a must be a real array because sizeof(a) is used.
#define Mem(a, x) memset(a, x, sizeof(a))
// Shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// INF and eps are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Size of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 1e6 + 5;
// Number of columns in the transition table (letters are indexed as ch - 'a').
const int maxs = 30;
// Reads one decimal integer from stdin, skipping any leading non-digit
// characters. A '-' directly in front of the first digit negates the result.
// Returns 0 when the input ends before a digit is found (the previous version
// spun forever on EOF).
inline ll read()
{
  ll ans = 0;
  // Keep the character as int: getchar() returns either an unsigned-char value
  // or EOF, which is exactly the domain isdigit() is defined for.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
// Prints x in decimal on stdout (leading '-' for negatives, no trailing
// separator).
// Negative values are emitted without ever negating x itself, because -x is
// signed overflow (undefined behaviour) when x is LLONG_MIN.
inline void write(ll x)
{
  if(x < 0)
    {
      putchar('-');
      // x / 10 is strictly greater than LLONG_MIN, so negating it is safe.
      if(x <= -10) write(-(x / 10));
      putchar('0' - static_cast<int>(x % 10));  // x % 10 lies in [-9, 0] here
      return;
    }
  if(x >= 10) write(x / 10);
  putchar(static_cast<int>(x % 10) + '0');
}

// Input buffer for the string read in main().
char s[maxn];
// Suffix automaton; letters are passed as 0-based indices (ch - 'a').
// State 0 is the root, further states are numbered 1..cnt in creation order.
//   tra[v][c] : target of the transition from v on letter c, 0 meaning "none"
//               (no transition ever leads back to the root, so 0 is free)
//   link[v]   : suffix link (parent in the parent tree), -1 for the root
//   len[v]    : length of the longest string represented by v
//   siz[v]    : 1 for states created by insert() as `now`, 0 for clones;
//               after dfs() the number of end positions of v
struct Sam
{
  int las, cnt;  // las: state of the whole string so far; cnt: highest state id
  int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1];
  int siz[maxn << 1];
  // Sets up the root. The other arrays are not cleared here; the single
  // instance S has static storage and therefore starts zeroed.
  In void init()
  {
    las = cnt = 0;
    link[0] = -1, len[0] = 0;
  }
  // Extends the automaton by one letter c (0-based).
  In void insert(int c)
  {
    int now = ++cnt;
    len[now] = len[las] + 1;
    int p = las;
    while(p != -1 && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 0;
    else
      {
        int q = tra[p][c];
        if(len[q] == len[p] + 1) link[now] = q;
        else
          {
            // q also holds strings longer than len[p] + 1: split off a clone
            // that keeps the short ones and becomes the parent of q and now.
            int clo = ++cnt;
            len[clo] = len[p] + 1;
            memcpy(tra[clo], tra[q], sizeof(tra[q]));
            link[clo] = link[q];
            link[q] = link[now] = clo;
            while(p != -1 && tra[p][c] == q) tra[p][c] = clo, p = link[p];
          }
      }
    siz[las = now] = 1;
  }
  int buc[maxn << 1], pos[maxn << 1];
  // Accumulates siz along the suffix links (children before parents, using a
  // counting sort by len) and returns the maximum of siz[v] * len[v] over all
  // states with siz[v] > 1, or 0 if there is none.
  // The product is formed in 64 bits: with up to 1e6 letters it reaches about
  // 2.5e11, which does not fit in int. Intended to be called once per build,
  // since buc and siz are modified in place.
  In ll dfs()
  {
    ll ret = 0;
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    for(int i = cnt; i; --i)
      {
        int now = pos[i];
        siz[link[now]] += siz[now];
        if(siz[now] > 1) ret = max(ret, (ll)siz[now] * len[now]);
      }
    return ret;
  }
}S;

// Reads the string, feeds it letter by letter into the automaton and prints
// the value computed by S.dfs().
int main()
{
  scanf("%s", s);
  S.init();
  const int n = strlen(s);
  for(const char* it = s; it != s + n; ++it) S.insert(*it - 'a');
  write(S.dfs());
  enter;
  return 0;
}


SP1811 LCS
求两个串的lcs。
把一个串建成后缀自动机,然后另一个串在上面跑,相当于枚举另一个串的前缀,看每一个前缀的后缀最多能和原串匹配多少。每成功匹配一个节点,就用当前匹配的长度更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble shared by the code below.
using namespace std;
// `enter` prints a newline via puts, `space` prints a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Fills a with the byte value x; a must be a real array because sizeof(a) is used.
#define Mem(a, x) memset(a, x, sizeof(a))
// Shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// INF and eps are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Size of each input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 2.5e5 + 5;
// Number of columns in the transition table (letters are indexed as ch - 'a').
const int maxs = 30;
// Reads one decimal integer from stdin, skipping any leading non-digit
// characters. A '-' directly in front of the first digit negates the result.
// Returns 0 when the input ends before a digit is found (the previous version
// spun forever on EOF).
inline ll read()
{
  ll ans = 0;
  // Keep the character as int: getchar() returns either an unsigned-char value
  // or EOF, which is exactly the domain isdigit() is defined for.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
// Prints x in decimal on stdout (leading '-' for negatives, no trailing
// separator).
// Negative values are emitted without ever negating x itself, because -x is
// signed overflow (undefined behaviour) when x is LLONG_MIN.
inline void write(ll x)
{
  if(x < 0)
    {
      putchar('-');
      // x / 10 is strictly greater than LLONG_MIN, so negating it is safe.
      if(x <= -10) write(-(x / 10));
      putchar('0' - static_cast<int>(x % 10));  // x % 10 lies in [-9, 0] here
      return;
    }
  if(x >= 10) write(x / 10);
  putchar(static_cast<int>(x % 10) + '0');
}

// Buffers for the two input strings read in main().
char s1[maxn], s2[maxn];
// Suffix automaton with root 0; letters are 0-based indices (ch - 'a').
// tra[v][c] == 0 means "no transition", link[v] is the suffix link (-1 for the
// root) and len[v] the length of the longest string of state v.
struct Sam
{
  int las, cnt;
  int tra[maxn << 1][maxs], len[maxn << 1], link[maxn << 1];
  // Makes state 0 the root and the current end state.
  In void init()
  {
    las = 0;
    cnt = 0;
    link[0] = -1;
  }
  // Appends letter c to the string the automaton represents.
  In void insert(int c)
  {
    const int cur = ++cnt;
    len[cur] = len[las] + 1;
    int up = las;
    // Give every suffix state that lacks a c-transition one to the new state.
    for(; up != -1 && tra[up][c] == 0; up = link[up]) tra[up][c] = cur;
    las = cur;
    if(up == -1)
      {
        link[cur] = 0;
        return;
      }
    const int nxt = tra[up][c];
    if(len[nxt] == len[up] + 1)
      {
        link[cur] = nxt;
        return;
      }
    // nxt has to be split: the copy takes over the strings up to len[up] + 1.
    const int copy = ++cnt;
    memcpy(tra[copy], tra[nxt], sizeof(tra[nxt]));
    len[copy] = len[up] + 1;
    link[copy] = link[nxt];
    link[cur] = copy;
    link[nxt] = copy;
    for(; up != -1 && tra[up][c] == nxt; up = link[up]) tra[up][c] = copy;
  }
  // Runs s through the automaton and returns the length of the longest
  // substring of s that also occurs in the string the automaton was built from.
  In int lcs(char* s)
  {
    int best = 0;
    int cur = 0, matched = 0;
    for(const char* it = s; *it != '\0'; ++it)
      {
        const int c = *it - 'a';
        if(tra[cur][c] != 0)
          {
            cur = tra[cur][c];
            ++matched;
          }
        else
          {
            // Drop to shorter suffixes until one can be extended by c.
            do cur = link[cur];
            while(cur != -1 && tra[cur][c] == 0);
            if(cur == -1) cur = matched = 0;
            else
              {
                matched = len[cur] + 1;
                cur = tra[cur][c];
              }
          }
        if(matched > best) best = matched;
      }
    return best;
  }
}S;

// Builds the automaton of the first string and prints the length of its
// longest common substring with the second one.
int main()
{
  scanf("%s%s", s1, s2);
  S.init();
  for(const char* it = s1; *it != '\0'; ++it) S.insert(*it - 'a');
  write(S.lcs(s2));
  enter;
  return 0;
}


SP1812 LCS2
求多个串的lcs。
还是先把一个串建成后缀自动机。
然后对于每一个串,都放在后缀自动机上跑,记录在每一个节点能匹配的最大长度。然后这些长度取min,就是所有串在每一个节点能匹配的最大长度。最后答案遍历每一个节点取max即可。
不过还得想一想的是,如果这个节点成功匹配了,那么他的所有祖先节点显然也是匹配了的,但是却只标记了这个节点。所以在每一个串跑完后缀自动机后,要从叶子节点开始把标记再往上传一遍,更新所有祖先节点的匹配情况。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble shared by the code below.
using namespace std;
// `enter` prints a newline via puts, `space` prints a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Fills a with the byte value x; a must be a real array because sizeof(a) is used.
#define Mem(a, x) memset(a, x, sizeof(a))
// Shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// INF and eps are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Size of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 1e5 + 5;
// Reads one decimal integer from stdin, skipping any leading non-digit
// characters. A '-' directly in front of the first digit negates the result.
// Returns 0 when the input ends before a digit is found (the previous version
// spun forever on EOF).
inline ll read()
{
  ll ans = 0;
  // Keep the character as int: getchar() returns either an unsigned-char value
  // or EOF, which is exactly the domain isdigit() is defined for.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
// Prints x in decimal on stdout (leading '-' for negatives, no trailing
// separator).
// Negative values are emitted without ever negating x itself, because -x is
// signed overflow (undefined behaviour) when x is LLONG_MIN.
inline void write(ll x)
{
  if(x < 0)
    {
      putchar('-');
      // x / 10 is strictly greater than LLONG_MIN, so negating it is safe.
      if(x <= -10) write(-(x / 10));
      putchar('0' - static_cast<int>(x % 10));  // x % 10 lies in [-9, 0] here
      return;
    }
  if(x >= 10) write(x / 10);
  putchar(static_cast<int>(x % 10) + '0');
}

// Input buffer, reused for every string read in main().
char s[maxn];
// Suffix automaton of the first string (root 0, letters as ch - 'a') plus the
// bookkeeping needed to intersect it with further strings.
//   tra[v][c] : transition target, 0 meaning "none"
//   link[v]   : suffix link, -1 for the root
//   len[v]    : length of the longest string of state v
//   Max[v]    : longest match ending in v for the string currently processed
//   Min[v]    : minimum of Max[v] over all strings processed so far
// Min and Max are not initialised here; the caller fills them before the
// first lcs() call.
struct Sam
{
  int las, cnt;
  int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1];
  In void init() {link[las = cnt = 0] = -1;}
  // Appends letter c to the string the automaton represents.
  In void insert(int c)
  {
    int now = ++cnt, p = las;
    len[now] = len[las] + 1;
    while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 0;
    else
      {
        int q = tra[p][c];
        if(len[q] == len[p] + 1) link[now] = q;
        else
          {
            int clo = ++cnt;
            memcpy(tra[clo], tra[q], sizeof(tra[q]));
            len[clo] = len[p] + 1;
            link[clo] = link[q]; link[q] = link[now] = clo;
            while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
          }
      }
    las = now;
  }
  int buc[maxn << 1], pos[maxn << 1];
  // Counting sort of states 1..cnt by len; afterwards pos[1..cnt] lists them
  // in non-decreasing len order. Call once after the automaton is complete.
  In void sort()
  {
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
  }
  int Max[maxn << 1], Min[maxn << 1];
  // Matches s against the automaton and folds the result into Min.
  // Requires sort() to have been called.
  In void lcs(char* s)
  {
    int n = strlen(s);
    for(int i = 0, p = 0, l = 0; i < n; ++i)
      {
        int c = s[i] - 'a';
        while(~p && !tra[p][c])
          {
            p = link[p];
            // Only read len[] for a real state: p becomes -1 when even the
            // root cannot be extended, and len[-1] is out of bounds.
            if(~p) l = len[p];
          }
        if(p == -1) p = l = 0;
        else ++l, p = tra[p][c], Max[p] = max(Max[p], l);
      }
    // A match in a state is also a match (capped by len) in all its suffix
    // link ancestors, so push the values upwards from long to short states.
    for(int i = cnt; i; --i)
      {
        int now = pos[i], fa = link[now];
        Max[fa] = max(Max[fa], min(Max[now], len[fa]));
        Min[now] = min(Min[now], Max[now]); Max[now] = 0;
      }
  }
}S;

int main()
{
  //freopen("ha.in", "r", stdin);
  scanf("%s", s);
  int n = strlen(s); S.init();
  for(int i = 0; i < n; ++i) S.insert(s[i] - 'a');
  S.sort();
  Mem(S.Min, 0x3f); Mem(S.Max, 0);
  while(scanf("%s", s) != EOF) S.lcs(s);
  int ans = 0;
  for(int i = 1; i <= S.cnt; ++i) ans = max(ans, S.Min[i]);
  write(ans), enter;
  return 0;
}


[USACO06DEC]Milk Patterns
找出现了至少\(k\)次的最长的子串。
建完后缀自动机,求一遍子树大小,然后如果大于等于\(k\)就用这个节点的len更新答案。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
#include<map>
// Contest-style preamble shared by the code below.
using namespace std;
// `enter` prints a newline via puts, `space` prints a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Fills a with the byte value x; a must be a real array because sizeof(a) is used.
#define Mem(a, x) memset(a, x, sizeof(a))
// Shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// INF and eps are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// The automaton arrays are sized maxn << 1.
const int maxn = 2e4 + 5;
// Reads one decimal integer from stdin, skipping any leading non-digit
// characters. A '-' directly in front of the first digit negates the result.
// Returns 0 when the input ends before a digit is found (the previous version
// spun forever on EOF).
inline ll read()
{
  ll ans = 0;
  // Keep the character as int: getchar() returns either an unsigned-char value
  // or EOF, which is exactly the domain isdigit() is defined for.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
// Prints x in decimal on stdout (leading '-' for negatives, no trailing
// separator).
// Negative values are emitted without ever negating x itself, because -x is
// signed overflow (undefined behaviour) when x is LLONG_MIN.
inline void write(ll x)
{
  if(x < 0)
    {
      putchar('-');
      // x / 10 is strictly greater than LLONG_MIN, so negating it is safe.
      if(x <= -10) write(-(x / 10));
      putchar('0' - static_cast<int>(x % 10));  // x % 10 lies in [-9, 0] here
      return;
    }
  if(x >= 10) write(x / 10);
  putchar(static_cast<int>(x % 10) + '0');
}

// n: number of input values; K: occurrence threshold used by Sam::dfs().
// Both are assigned in main().
int n, K;

// Suffix automaton over an integer alphabet (root 0). Transitions are kept in
// one std::map per state, so symbols may be arbitrary ints.
//   link[v] : suffix link, -1 for the root
//   len[v]  : length of the longest string of state v
//   siz[v]  : 1 for states created by insert() as `now`, 0 for clones;
//             after dfs() the number of occurrences of the strings of v
struct Sam
{
  int las, cnt;
  map<int, int> tra[maxn << 1];
  int len[maxn << 1], link[maxn << 1], siz[maxn << 1];
  In void init() {link[las = cnt = 0] = -1;}
  // Appends symbol x to the sequence the automaton represents.
  In void insert(int x)
  {
    int now = ++cnt, p = las;
    len[now] = len[las] + 1; siz[now] = 1;
    while(~p && !tra[p].count(x)) tra[p][x] = now, p = link[p];
    if(p == -1) link[now] = 0;
    else
      {
        int q = tra[p][x];
        if(len[q] == len[p] + 1) link[now] = q;
        else
          {
            int clo = ++cnt;
            tra[clo] = tra[q]; len[clo] = len[p] + 1;
            link[clo] = link[q];
            link[q] = link[now] = clo;
            // Redirect with find() rather than operator[]: the latter would
            // silently insert an x -> 0 entry if a state lacked the edge.
            while(~p)
              {
                auto it = tra[p].find(x);
                if(it == tra[p].end() || it->second != q) break;
                it->second = clo, p = link[p];
              }
          }
      }
    las = now;
  }
  int pos[maxn << 1]; 
  // Accumulates siz along the suffix links (longest states first) and returns
  // the largest len[v] among states with siz[v] >= K, or 0 if there is none.
  // Intended to be called once per build, since siz is modified in place.
  In int dfs()
  {
    int ret = 0;
    for(int i = 1; i <= cnt; ++i) pos[i] = i;
    // The comparator takes its arguments by value so it is callable on const
    // elements, and names `this` explicitly because len is a data member.
    sort(pos + 1, pos + cnt + 1, [this](int a, int b) {return len[a] > len[b];});
    for(int i = 1; i <= cnt; ++i)
      {
        siz[link[pos[i]]] += siz[pos[i]];
        if(siz[pos[i]] >= K) ret = max(ret, len[pos[i]]);
      }
    return ret;
  }
}S;

// Reads n and K followed by n values, builds the automaton of the sequence
// and prints the result of S.dfs().
int main()
{
  n = read();
  K = read();
  S.init();
  for(int i = 0; i < n; ++i) S.insert(read());
  write(S.dfs());
  enter;
  return 0;
}


[TJOI2015]弦论
求第\(k\)小的子串。
因为每一个子串代表一条路径,所以我们求出从每一个节点开始有多少条路径,然后像平衡树找第\(k\)大的方法找即可。
题目还分了两种情况:\(t\)为0的话,相同的子串只算一次,每一个节点的endpos数量就都当作1;\(t\)为1的话,每一个节点的endpos数量就是子树大小了。
至于怎么求从每一个点开始的路径数量,上面第二篇博客有讲。

#include<cstdio>
#include<iostream>
#include<cmath>
#include<algorithm>
#include<cstring>
#include<cstdlib>
#include<cctype>
#include<vector>
#include<stack>
#include<queue>
// Contest-style preamble shared by the code below.
using namespace std;
// `enter` prints a newline via puts, `space` prints a single blank via putchar.
#define enter puts("") 
#define space putchar(' ')
// Fills a with the byte value x; a must be a real array because sizeof(a) is used.
#define Mem(a, x) memset(a, x, sizeof(a))
// Shorthand for the `inline` keyword.
#define In inline
typedef long long ll;
typedef double db;
// INF and eps are not referenced by the code visible in this program.
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
// Size of the input buffer; the automaton arrays are sized maxn << 1.
const int maxn = 5e5 + 5;
// Reads one decimal integer from stdin, skipping any leading non-digit
// characters. A '-' directly in front of the first digit negates the result.
// Returns 0 when the input ends before a digit is found (the previous version
// spun forever on EOF).
inline ll read()
{
  ll ans = 0;
  // Keep the character as int: getchar() returns either an unsigned-char value
  // or EOF, which is exactly the domain isdigit() is defined for.
  int ch = getchar(), last = ' ';
  while(ch != EOF && !isdigit(ch)) last = ch, ch = getchar();
  while(ch != EOF && isdigit(ch)) ans = ans * 10 + (ch - '0'), ch = getchar();
  if(last == '-') ans = -ans;
  return ans;
}
// Prints x in decimal on stdout (leading '-' for negatives, no trailing
// separator).
// Negative values are emitted without ever negating x itself, because -x is
// signed overflow (undefined behaviour) when x is LLONG_MIN.
inline void write(ll x)
{
  if(x < 0)
    {
      putchar('-');
      // x / 10 is strictly greater than LLONG_MIN, so negating it is safe.
      if(x <= -10) write(-(x / 10));
      putchar('0' - static_cast<int>(x % 10));  // x % 10 lies in [-9, 0] here
      return;
    }
  if(x >= 10) write(x / 10);
  putchar(static_cast<int>(x % 10) + '0');
}

// Flg: 0 = equal substrings count once, non-zero = every occurrence counts.
// K: rank of the substring to print. Both are assigned in main().
int Flg, K;
// Input buffer for the string read in main().
char s[maxn];
// Suffix automaton with root 1 (index 0 is the "no transition" value, and
// sum[0] stays 0), letters as ch - 'a'.
//   link[v] : suffix link, -1 for the root
//   len[v]  : length of the longest string of state v
//   siz[v]  : after dfs(), how often one string of v is counted
//             (1 if Flg == 0, its number of occurrences otherwise; 0 for root)
//   sum[v]  : after dfs(), number of counted strings obtained by walking on
//             from v, including stopping at v itself
struct Sam
{
  int las, cnt;
  int tra[maxn << 1][30], len[maxn << 1], link[maxn << 1], siz[maxn << 1];
  In void init() {link[las = cnt = 1] = -1;}
  // Appends letter c to the string the automaton represents.
  In void insert(int c)
  {
    int now = ++cnt, p = las;
    len[now] = len[las] + 1; siz[now] = 1;
    while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
    if(p == -1) link[now] = 1;
    else
      {
        int q = tra[p][c];
        if(len[q] == len[p] + 1) link[now] = q;
        else
          {
            int clo = ++cnt;
            memcpy(tra[clo], tra[q], sizeof(tra[q]));
            len[clo] = len[p] + 1;
            link[clo] = link[q], link[q] = link[now] = clo;
            while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
          }
      }
    las = now;
  }
  int buc[maxn << 1], pos[maxn << 1];
  // 64-bit: for 5e5 letters the number of substrings reaches about 1.25e11.
  ll sum[maxn << 1];
  char ans[maxn];  // scratch buffer used by print()
  // Computes siz and sum for every state according to Flg. Intended to be
  // called once after the automaton is complete (buc and siz are modified in
  // place).
  In void dfs()
  {
    for(int i = 1; i <= cnt; ++i) ++buc[len[i]];
    for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
    for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
    // pos[1] is the root (the only state with len 0). It has no parent, so
    // the upward accumulation must stop before it: link[root] is -1.
    for(int i = cnt; i > 1; --i) siz[link[pos[i]]] += siz[pos[i]];
    for(int i = 1; i <= cnt; ++i)
      {
        if(!Flg) sum[i] = siz[i] = 1;
        else sum[i] = siz[i];
      }
    // The root stands for the empty string, which is not a substring to be
    // ranked; it must contribute nothing to either counter.
    sum[1] = siz[1] = 0;
    for(int i = cnt; i; --i)
      for(int j = 0; j < 26; ++j)
        if(tra[pos[i]][j]) sum[pos[i]] += sum[tra[pos[i]][j]];
  }
  // Stores the k-th smallest counted substring in out (NUL-terminated) and
  // returns its length, or returns -1 without touching out when fewer than k
  // substrings exist. out needs room for the longest possible answer plus the
  // terminator. Requires dfs() to have run.
  In int kth(ll k, char* out)
  {
    if(sum[1] < k) return -1;
    int now = 1, n = 0;
    // k > 0 rather than k != 0: with Flg set, subtracting siz[now] can step
    // past zero when the answer is one of several equal occurrences.
    while(k > 0)
      {
        int c = 0;
        while(c < 26 && k > sum[tra[now][c]]) k -= sum[tra[now][c++]];
        if(c == 26) break;  // unreachable while sum is consistent; guards the table
        now = tra[now][c];
        out[n++] = 'a' + c; k -= siz[now];
      }
    out[n] = '\0';
    return n;
  }
  // Prints the k-th smallest counted substring, or -1 if there is none.
  In void print(int k)
  {
    int n = kth(k, ans);
    if(n < 0) {write(-1); return;}
    for(int i = 0; i < n; ++i) putchar(ans[i]);
  }
}S;

// Reads the string followed by the two parameters, builds the automaton and
// prints the requested substring via S.print().
int main()
{
  scanf("%s", s);
  S.init();
  for(int i = 0, n = strlen(s); i < n; ++i) S.insert(s[i] - 'a');
  Flg = read();
  K = read();
  S.dfs();
  S.print(K);
  enter;
  return 0;
}

转载于:https://www.cnblogs.com/mrclr/p/10450821.html

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值