Description:
题解:
首先思考一个经典的问题怎么做?
就是求一个串的子串的不同子串数,这玩意儿冬令营上《字符串前沿算法》讲过。
考虑在把整个串的SAM建出来,注意建的时候记录每个结尾点。
然后一个一个把结尾点激活,之所以要先建出来的原因是树的形态就稳固了。
假设要查询[l…r]的不同子串数,先把r以前的结尾点激活。
如果知道每个点的rmax,这个点的深度代表的是[x…y],那么就给[rmax-y+1…rmax-x+1]+1,然后区间查询[l…r]即可知道不同子串数。
考虑如何动态维护这个东西。
顺序激活,假设这次激活点x,那么x到根的这一条链的rmax都会变成当前的i。
考虑用lct结构维护,在一棵splay的点是一条链,且这一条链的rmax相同。
用lct的access去更改rmax,因为一条链的深度是一段区间
lct的access的势能分析复杂度是 O ( n l o g n ) O(n log n) O(nlogn)的,所以总复杂度是 O ( n l o g 2 n ) O(n log^2n) O(nlog2n)
然后开始无聊的猜结论。
把s变成最小完整周期串,这个跑个KMP就行了。
现在要证明不可能有两个左端点∈[1…s]长度>=n的串是相同的。
长度大于可以n可以视为=n,这是一样的,因为这是一个循环串。
设左端点为i,j
不难得到对任意x都有s[x]=s[x+|i-j|]
如何必然有一个长度为gcd(|i-j|,n)的周期,与题设矛盾。
这个可以用周期定理解释。
那么只需要需要求3n以内的答案,超过3n的话每次答案加n。
Code:
#include<cstdio>
#include<cstring>
#include<vector>
#define db double
#define pp printf
#define ll long long
#define fo(i, x, y) for(int i = x, B = y; i <= B; i ++)
using namespace std;
const int N = 6e5 + 5;
char s[N]; int n;
int nt[N];
int son[N][26], FA[N], dep[N], la, tot;
void build() {
fo(i, 1, tot) memset(son[i], 0, sizeof son[i]), FA[i] = 0;
la = tot = 1;
}
#define push(v) dep[++ tot] = v
void add(int c) {
int p = la;
push(dep[p] + 1); int np = tot;
for(; p && !son[p][c]; p = FA[p]) son[p][c] = np;
if(!p) FA[np] = 1; else {
int q = son[p][c];
if(dep[p] + 1 < dep[q]) {
push(dep[p] + 1); int nq = tot;
memcpy(son[nq], son[q], sizeof son[q]);
FA[nq] = FA[q]; FA[q] = FA[np] = nq;
for(; son[p][c] == q; p = FA[p]) son[p][c] = nq;
} else FA[np] = q;
}
la = np;
}
int ed[N];
int pl, pr, tt; ll px;
ll tr[N * 4]; int lz[N * 4];
#define i0 i + i
#define i1 i + i + 1
void down(int i, int x, int y) {
if(lz[i]) {
int m = x + y >> 1;
tr[i0] += lz[i] * (m - x + 1); lz[i0] += lz[i];
tr[i1] += lz[i] * (y - m); lz[i1] += lz[i];
lz[i] = 0;
}
}
void add(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) { tr[i] += (y - x + 1) * px, lz[i] += px; return ;}
int m = x + y >> 1; down(i, x, y);
add(i0, x, m); add(i1, m + 1, y);
tr[i] = tr[i0] + tr[i1];
}
void ft(int i, int x, int y) {
if(y < pl || x > pr) return;
if(x >= pl && y <= pr) { px += tr[i]; return ;}
int m = x + y >> 1; down(i, x, y);
ft(i0, x, m); ft(i1, m + 1, y);
}
int t[N][2], pf[N], fa[N], rm[N];
int lr(int x) { return t[fa[x]][1] == x;}
void rotate(int x) {
int y = fa[x], k = lr(x);
t[y][k] = t[x][!k]; if(t[x][!k]) fa[t[x][!k]] = y;
fa[x] = fa[y]; if(fa[y]) t[fa[y]][lr(y)] = x;
fa[y] = x; t[x][!k] = y; pf[x] = pf[y]; rm[x] = rm[y];
}
int fl(int x) {
return t[x][0] ? fl(t[x][0]) : x;
}
void splay(int x, int y) {
for(; fa[x] != y; rotate(x)) if(fa[fa[x]] != y)
rotate(lr(x) == lr(fa[x]) ? fa[x] : x);
}
void access(int x) {
for(int y = 0; x; ) {
splay(x, 0), fa[t[x][1]] = 0; rm[t[x][1]] = rm[x]; pf[t[x][1]] = x;
int z = fl(x); if(x != z) splay(z, x);
pl = rm[x] - dep[x] + 1; pr = rm[x] - (dep[FA[z]] + 1) + 1; px = -1;
add(1, 1, 3 * n);
t[x][1] = y; fa[y] = x; pf[y] = 0;
y = x; x = pf[x];
}
}
struct nod {
int x, y, i;
} A;
vector<nod> a[N];
int Q, x, y, ax[N], ay[N], bx[N], by[N];
ll ans[N];
int main() {
scanf("%s", s + 1); n = strlen(s + 1);
{
int x = 0;
fo(i, 2, n) {
while(x && s[x + 1] != s[i]) x = nt[x];
x += s[x + 1] == s[i]; nt[i] = x;
}
if(n % (n - nt[n]) == 0) n = n - nt[n];
}
build();
fo(ii, 1, 3) fo(i, 1, n) add(s[i] - 'a'), ed[(ii - 1) * n + i] = la;
fo(i, 1, tot) pf[i] = FA[i];
scanf("%d", &Q);
fo(i, 1, Q) {
scanf("%d %d", &x, &y);
int z = (x - 1) / n;
x -= z * n; y -= z * n;
if(y <= 3 * n) {
bx[i] = x; by[i] = y; A.i = i;
} else {
bx[i] = x; by[i] = x + 2 * n - 1; A.i = i;
}
A.x = bx[i]; A.y = by[i]; A.i = i;
a[A.y].push_back(A);
ax[i] = x; ay[i] = y;
}
fo(i, 1, 3 * n) {
int x = ed[i];
access(x); splay(x, 0); rm[x] = i;
pr = i; pl = i - dep[x] + 1; px = 1;
add(1, 1, 3 * n);
fo(j, 0, a[i].size() - 1) {
pl = a[i][j].x; pr = a[i][j].y; px = 0;
ft(1, 1, 3 * n); ans[a[i][j].i] = px;
}
}
fo(i, 1, Q) pp("%lld\n", ans[i] + (ll) (ay[i] - by[i]) * n);
}