Given a string, we need to find the total number of its distinct substrings.
Input
T - the number of test cases (T <= 20).
Each test case consists of one string, whose length is <= 1000
Output
For each test case output one number saying the number of distinct substrings.
Example
Sample Input:
2
CCCCC
ABABA
Sample Output:
5
9
Explanation for the testcase with string ABABA:
len=1 : A,B
len=2 : AB,BA
len=3 : ABA,BAB
len=4 : ABAB,BABA
len=5 : ABABA
Thus, total number of distinct substrings is 9.
题意:
统计不同子串个数
首先,所有的子串都可以看成某个后缀的前缀。若后缀的起点为 p(下标从 1 开始),它一共有 n - p + 1 个前缀;由于 sa 数组保存的是从 0 开始的下标,所以排名为 i 的后缀贡献 n - sa[ i ] 个子串。
下一步,如何消去重复的:height[ i ] 记录的是排名相邻的两个后缀(排名 i - 1 与排名 i)的最长公共前缀长度,这些公共前缀在统计前一个后缀时已经算过,也就是重复的子串;从排名为 i 的后缀的贡献中将其减去,剩下的就是新增的不同子串个数。
推出公式:答案 = Σ ( n - sa[ i ] - height[ i ] ),i = 1...n( 排名 0 是末尾补上的哨兵后缀,所以 height[ 1 ] = 0,不影响结果 )
AC代码
#include<iostream>
#include<algorithm>
#include<cstdio>
#include<cstring>
// Loop shorthands: rep walks i upward over [a, n), per walks i downward over [a, n).
#define rep(i,a,n) for(int i=a;i<n;i++)
#define per(i,a,n) for(int i=n-1;i>=a;i--)
// Loops over [0, x) with the fixed loop variable names i and j.
#define fori(x) for(int i=0;i<x;i++)
#define forj(x) for(int j=0;j<x;j++)
// NOTE(review): these two macros reuse the names of the <cstring> functions and
// append the size argument themselves (sizeof(x) for memset, sizeof(y) for memcpy).
// After this point the names only accept two arguments, and sizeof on a pointer
// argument would yield the pointer size rather than the buffer size.
#define memset(x,y) memset(x,y,sizeof(x))
#define memcpy(x,y) memcpy(x,y,sizeof(y))
// scanf wrappers: sca* read one to three ints ("%d"), scas reads one
// whitespace-delimited token ("%s"), scl* read one to three long longs ("%lld").
#define sca(x) scanf("%d", &x)
#define scas(x) scanf("%s",x)
#define sca2(x,y) scanf("%d%d",&x,&y)
#define sca3(x,y,z) scanf("%d%d%d",&x,&y,&z)
#define scl(x) scanf("%lld",&x)
#define scl2(x,y) scanf("%lld%lld",&x,&y)
#define scl3(x,y,z) scanf("%lld%lld%lld",&x,&y,&z)
// printf wrappers; each one appends a newline.
#define pri(x) printf("%d\n",x)
#define pri2(x,y) printf("%d %d\n",x,y)
#define pris(x) printf("%s\n",x)
#define prl(x) printf("%lld\n",x)
typedef long long ll;
// Generic constants; none of the three is referenced by the code visible in this file.
const int maxn=1e6+7;
const int mod=1e9+7;
const double eps=1e-8;
using namespace std;
// Capacity of every suffix-array work buffer below.
const int N = 2000 + 10;
// sa     - written by Rsort(): sa[r] is the start index of the suffix with rank r.
// height - written by get_height(): common-prefix length of the suffixes ranked r-1 and r.
// rnk    - written by get_height(): inverse of sa (rnk[sa[r]] == r).
// c      - bucket counters used by Rsort().
// wa, wb - the two key buffers that da() swaps between rounds.
int sa[N], height[N], rnk[N], c[N], wa[N], wb[N];
// Tells whether positions a and b carry the same key pair in r: the key at the
// position itself and the key l entries further on. The second pair is only
// inspected when the first pair already matches.
bool cmp(int *r, int a, int b, int l)
{
    if (r[a] != r[b])
    {
        return false;
    }
    return r[a + l] == r[b + l];
}
// Stable counting sort of the n indices listed in y by their key x[y[i]].
// Keys are used as indices into the global bucket array c, so they must lie in
// [0, m). The sorted order is written to the global sa.
void Rsort(int *x, int *y, int n, int m)
{
    // Clear the m buckets.
    for (int key = 0; key < m; ++key)
    {
        c[key] = 0;
    }

    // Histogram of the keys.
    for (int pos = 0; pos < n; ++pos)
    {
        ++c[x[y[pos]]];
    }

    // Prefix sums: c[key] becomes one past the last output slot for that key.
    for (int key = 1; key < m; ++key)
    {
        c[key] += c[key - 1];
    }

    // Walk y backwards so entries with equal keys keep their relative order.
    int pos = n;
    while (pos-- > 0)
    {
        const int idx = y[pos];
        const int slot = --c[x[idx]];
        sa[slot] = idx;
    }
}
void da(int *s, int n, int m)
{
int *x = wa, *y = wb;
for(int i = 0; i < n; i++) x[i] = s[i], y[i] = i;
Rsort(x, y, n, m);
for(int j = 1, p = 1; p < n; j *= 2, m = p)
{
p = 0;
for(int i = n-j; i < n; i++) y[p++] = i;
for(int i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i] - j;
Rsort(x, y, n, m);
swap(x, y);
p = 1; x[sa[0]] = 0;
for(int i = 1; i < n; i++) x[sa[i]] = cmp(y, sa[i-1], sa[i], j) ? p-1 : p++;
}
}
// Fills the global rnk and height arrays from the global sa.
//   s - the keys that sa was built from
//   n - number of text positions to process; sa is read for ranks 0..n inclusive
// rnk becomes the inverse of sa, and height[rnk[i]] the length of the common
// prefix of suffix i and the suffix ranked directly before it. Positions are
// visited in text order so the previous match length, minus one, can be reused
// as the starting point for the next comparison.
void get_height(int *s, int n)
{
    for (int r = 0; r <= n; ++r)
    {
        rnk[sa[r]] = r;
    }

    int matched = 0;
    for (int pos = 0; pos < n; ++pos)
    {
        if (matched != 0)
        {
            --matched;
        }
        const int prev = sa[rnk[pos] - 1];
        while (s[pos + matched] == s[prev + matched])
        {
            ++matched;
        }
        height[rnk[pos]] = matched;
    }
}
// Input buffer for one test-case string; main() fills it with scanf("%s").
char s[N];
// Integer copy of s with a trailing 0 sentinel; this is what main() hands to
// da() and get_height().
int a[N];
int main()
{
int t;
sca(t);
while(t--)
{
int ma = -1;
scas(s);
int n = strlen(s);
rep(i,0,n)
{
a[i] = s[i];
ma = max(ma,a[i]);
}
a[n] = 0;
da(a,n+1,ma+1);
get_height(a,n);
int res = n - sa[1];
rep(i,2,n+1)
{
res += (n - sa[i] - height[i]);
}
pri(res);
}
return 0;
}