题目链接 , 题意已经很清楚了
类似 NOI2015 品酒大会 的做法
比那道题还要简单一些,只需要统计长度为 $r$ 的相同子串的个数 $cnt[r]$(即最长公共前缀恰为 $r$ 的后缀对数)
$$ans = \frac{1}{2}(n-1)\,n\,(n+1) - 2\sum_{r=1}^{n-1} cnt[r]\cdot r$$
我是用并查集实现的,当然也可以用单调栈
//bzoj 3238
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <ctime>
#include <vector>
#include <utility>
#include <stack>
#include <queue>
#include <iostream>
#include <algorithm>
/*
 * Reads one decimal integer from stdin into x.
 *
 * Every non-digit character before the number is skipped; each '-' seen
 * while skipping toggles the sign. The first non-digit after the number is
 * consumed as well.
 * If the input ends before any digit is found, x is set to 0.
 */
template<class Num>void read(Num &x)
{
    // getchar() returns int: keeping it in an int is what lets EOF be told
    // apart from a real character (a char would make the skip loop spin
    // forever once the input is exhausted).
    int ch = getchar();
    bool negative = false;
    while(ch != EOF && (ch < '0' || ch > '9'))
    {
        if(ch == '-') negative = !negative;
        ch = getchar();
    }
    x = 0;
    if(ch == EOF) return;
    do
    {
        x = x * 10 + (ch - '0');
        ch = getchar();
    } while(ch >= '0' && ch <= '9');
    if(negative) x = -x;
}
/*
 * Prints the integer x to stdout in decimal, with a leading '-' when it is
 * negative. No separator or newline is written.
 */
template<class Num>void write(Num x)
{
    if(x == 0) {putchar('0');return;}
    const bool negative = x < 0;
    // Digits are collected least-significant first. Taking the remainder of
    // the (possibly negative) value avoids computing -x, which overflows for
    // the minimum value of a signed type.
    char digits[40]; int len = 0;
    while(x != 0)
    {
        int d = (int)(x % 10);
        if(d < 0) d = -d;
        digits[len++] = (char)('0' + d);
        x /= 10;
    }
    if(negative) putchar('-');
    while(len > 0) putchar(digits[--len]);
}
// Capacity of the global arrays below.
const int maxn = 500020;

// Input string, stored 1-based: the characters live in s[1..n].
char s[maxn];

// Length of the input string.
int n;

// Suffix array: sa[i] is the start position of the i-th smallest suffix.
int sa[maxn];

// Inverse of sa: rank[sa[i]] == i.
int rank[maxn];

// Bucket counters for the counting sorts in build_sa.
int c[maxn];

// height[i]: length of the common prefix of the suffixes at sorted
// positions i and i - 1 (filled by build_height for i >= 2).
int height[maxn];

// Scratch stack and its top index; not referenced by the code in view.
int stack[maxn];
int top;

// (height, sorted position) pairs ordered by solve().
std::pair<int,int> add[maxn];
/*
 * Builds the suffix array of s[1..n] into sa[1..n] by prefix doubling,
 * using a counting sort in every pass.
 *
 * m is the number of buckets for the first pass: every character of s,
 * taken as an unsigned char, must lie in 1..m.
 * Reads the globals s and n, writes sa, and uses c as bucket counters.
 */
void build_sa(int m)
{
    static int t0[maxn], t1[maxn];
    int *x = t0, *y = t1;   // x: current sort key of each suffix, y: scratch

    // First pass: counting sort of the suffixes by their first character.
    for(int i = 1; i <= m; i++) c[i] = 0;
    // Index through unsigned char: a plain char may be signed, and a byte
    // >= 0x80 would then produce a negative bucket index.
    for(int i = 1; i <= n; i++) c[x[i] = static_cast<unsigned char>(s[i])]++;
    for(int i = 1; i <= m; i++) c[i] += c[i - 1];
    for(int i = n; i >= 1; i--) sa[c[x[i]]--] = i;

    // Each pass extends the sorted prefix length from k to 2k.
    for(int k = 1; k <= n; k <<= 1)
    {
        // y[1..n]: suffix positions ordered by their second half (the key of
        // the suffix k positions later). The last k suffixes have no second
        // half, so they come first.
        int p = 0;
        for(int i = 0; i < k; i++) y[++p] = n - i;
        for(int i = 1; i <= n; i++)
            if(sa[i] > k) y[++p] = sa[i] - k;

        // Stable counting sort of that order by the first-half key x.
        for(int i = 1; i <= m; i++) c[i] = 0;
        for(int i = 1; i <= n; i++) c[x[y[i]]]++;
        for(int i = 1; i <= m; i++) c[i] += c[i - 1];
        for(int i = n; i >= 1; i--) sa[c[x[y[i]]]--] = y[i];

        // y now holds the old keys; assign the new keys into x. Two
        // neighbours share a key only if both halves exist and are equal.
        std::swap(x, y), x[sa[p = 1]] = 1;
        for(int i = 2; i <= n; i++)
            x[sa[i]] = y[sa[i]] == y[sa[i - 1]] && sa[i] + k <= n && sa[i - 1] + k <= n && y[sa[i] + k] == y[sa[i - 1] + k] ? p : ++p;

        if(p == n) break;   // all keys distinct: the order is final
        m = p;              // the next pass needs only p buckets
    }
}
/*
 * Fills rank[] as the inverse of sa[] and then height[]: for every sorted
 * position r >= 2, height[r] is the length of the common prefix of the
 * suffixes at sorted positions r and r - 1.
 *
 * Suffixes are visited in text order so that the match length carried over
 * from the previous suffix (minus one) can be reused as a starting point.
 */
void build_height()
{
    for(int pos = 1; pos <= n; ++pos)
        rank[sa[pos]] = pos;

    int lcp = 0;
    for(int start = 1; start <= n; ++start)
    {
        if(lcp > 0) --lcp;
        const int order = rank[start];
        if(order != 1)
        {
            const int prev = sa[order - 1];
            // The comparison stops at the latest on the zero byte that
            // follows the string.
            while(s[start + lcp] == s[prev + lcp]) ++lcp;
            height[order] = lcp;
        }
    }
}
// cnt is not referenced by the code in view.
long long cnt;

// Result accumulated by solve() and gather().
long long ans;

// Union-find parent links.
int fa[maxn];

// Number of elements in the set whose root is the index.
int size[maxn];
/*
 * Returns the root of the set containing x and makes every node on the path
 * from x to the root point directly at it.
 *
 * Written iteratively: gather() links sets by index rather than by size, so
 * parent chains can grow to a length proportional to n, and a recursive walk
 * could exhaust the call stack.
 */
int find(int x)
{
    // Pass 1: locate the root.
    int root = x;
    while(fa[root] != root) root = fa[root];

    // Pass 2: re-point every node on the path at the root.
    while(fa[x] != root)
    {
        const int next = fa[x];
        fa[x] = root;
        x = next;
    }
    return root;
}
/*
 * Merges the sets containing x and y, unless they are already the same set.
 * Before merging, ans is reduced by 2 * r for every pair made of one element
 * from each set. The root with the smaller index becomes the root of the
 * merged set.
 */
void gather(int x,int y,int r)
{
    int keep = find(x);
    int drop = find(y);
    if(keep == drop) return;
    if(keep > drop) std::swap(keep, drop);

    const long long pairs = (long long)size[keep] * size[drop];
    ans -= 2 * (pairs * r);

    size[keep] += size[drop];
    fa[drop] = keep;
}
/*
 * Computes the answer into the global ans from n and height[].
 *
 * ans starts as the sum of len(i) + len(j) over all pairs of suffixes, i.e.
 * (n - 1) * n * (n + 1) / 2. Adjacent sorted positions are then merged in
 * order of decreasing height; gather() subtracts 2 * height for every pair
 * of suffixes that a merge joins.
 */
void solve()
{
    ans = (long long)(n + 1) * n / 2 * (n - 1);
    for(int i = 1; i <= n; i++)
        fa[i] = i, size[i] = 1;

    // With fewer than two suffixes there is nothing to merge. Returning here
    // also keeps std::sort from being called with an inverted range
    // (add + 1, add + 0) when n == 0.
    if(n < 2) return;

    // add[i] describes the boundary between sorted positions i and i + 1.
    for(int i = 1; i < n; i++)
        add[i] = std::make_pair(height[i + 1], i);
    std::sort(add + 1, add + n);

    // Walk the ascending order backwards: largest height first.
    for(int i = n - 1; i >= 1; i--)
    {
        const int v = add[i].second;
        gather(v, v + 1, add[i].first);
    }
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("3238.in","r",stdin);
freopen("3238.out","w",stdout);
#endif
scanf("%s", s + 1);
n = strlen(s + 1);
build_sa(256);
build_height();
solve();
write(ans);
#ifndef ONLINE_JUDGE
fclose(stdin);
fclose(stdout);
#endif
return 0;
}