题目链接:http://nanti.jisuanke.com/t/11144
题目概述:给定一个字符串,求不相交的两个最大循环节相同的子串的对数。
样例解释:
这题的描述读起来很怪,但细想也确实传达了他想表达的意思。
比如abcd这个串,实际上任意的子串最大循环节都是1,所以一共有15对目标组合:
(a,b),(a,c),(a,d),(b,c),(b,d),(c,d),(ab,cd),(abc,d),(a,bcd),(b,cd),(ab,c),(ab,d),(a,cd),(a,bc),(bc,d)
思路概述:
先只考虑重复次数至少为2的子串,枚举长度 L,那么重复次数至少为2的子串必定包括了 S_0,S_L,S_{L*2} …中某相邻的两个,所以只需要看看 S_{L*i}和 S_{L*(i+1)}往前和往后最多能匹配到多远。这一步可以使用后缀数组完成,而不需要一位一位去挪。正向反向分别看两个相邻的L区间起始的后缀的最长公共前缀是多少即可。这里有一点需要注意,一个区间一旦在一个小的L值时被选择,对于一个新的较大的L,如果它也选择这个区间,那么直接忽略就好,因为由它分割的区间一定不是最大循环节了。
假设往左能匹配到left,往右能匹配到right,那么可以处理出一些子串区间 Q(l,r,x,L)表示区间(l,r)中,任意连续的x*L子串,都是由长度为L的一个子子串循环x次构成的。所有这些区间个数最多有 nlogn 个。区间搞出来后,对于x相同的区间,将所有区间插入线段树,在线段树里维护cnt_i(记为S字段)和 i*cnt_i(记为S2字段),其中cnt_i表示右端点为i的区间个数。然后枚举所有区间,则简单的推导可以得到(记y=r-l+2-x*L表示一个区间实际上所表示的子串的数量)在该区间左边的区间个数:
于是对于重复次数至少为2的相同重复次数的子串对的个数已经求出来了。
然后根据已知的所有重复次数至少为2的区间可以求出 cntl_i和 cntr_i,分别表示重复次数等于1的左端点为i的子串个数和右端点为i的子串个数,结合线段树,以及以每个点i(1..len)为左端点和右端点的区间数分别为(len+1-i)和i。那么重复次数为1的子串对就很容易可以求出了。
总体复杂度为 O(nlogn)。
#include <bits/stdc++.h>
#define pb push_back
#define mp make_pair
#define F(x) ((x) / 3 + ((x) % 3 == 1 ? 0 : tb))
#define G(x) ((x) < tb ? (x) * 3 + 1 :((x) - tb) * 3 + 2)
#define lson l , m , rt << 1
#define rson m + 1 , r , rt << 1 | 1
using namespace std;
const double eps(1e-8);
typedef long long ll;
typedef pair<int,int> pii;
#define ws Ws
#define rank rrrank
const int maxn = 100100;
const int MAXN = maxn*3;
const int MAXM = maxn;
struct Node
{
int l,r,x,L;
Node() {}
Node(int l,int r,int x,int L):l(l),r(r),x(x),L(L) {}
bool operator < (const Node &o) const
{
if(x != o.x) return x<o.x;
if(l != o.l) return l<o.l;
return r<o.r;
}
};
vector<Node> nodes;
int ntot = 0;
int rank[MAXN], height[maxn];
int rank2[MAXN], height2[maxn];
int wa[MAXN],wb[MAXN],ws[MAXN],wv[MAXN],wsd[MAXN],r[MAXN],sa[MAXN];
int c0(int *r,int a,int b)
{
return r[a] == r[b] && r[a + 1] == r[b + 1] && r[a + 2] == r[b + 2];
}
int c12(int k,int *r,int a,int b)
{
if(k == 2) return r[a] < r[b] || r[a] == r[b] && c12(1,r,a + 1,b + 1);
else return r[a] < r[b] || r[a] == r[b] && wv[a + 1]< wv[b + 1];
}
void sort(int *r,int *a,int *b,int n,int m)
{
int i;
for(i = 0 ; i < n ; i++) wv[i] = r[a[i]];
for(i = 0 ; i < m ; i++) wsd[i] = 0;
for(i = 0 ; i < n ; i++) wsd[wv[i]]++;
for(i = 1 ; i < m ; i++) wsd[i] += wsd[i - 1];
for(i = n - 1 ; i >= 0 ; i--) b[--wsd[wv[i]]] = a[i];
}
void dc3(int *r,int *sa,int n,int m)
{
int i,j,*rn = r + n ,*san = sa + n,ta = 0,tb = (n + 1) / 3,tbc = 0,p;
r[n] = r[n + 1] = 0;
for(i = 0 ; i < n ; i++) if(i % 3 != 0) wa[tbc++] = i;
sort(r + 2,wa,wb,tbc,m);
sort(r + 1,wb,wa,tbc,m);
sort(r,wa,wb,tbc,m);
for(p = 1,rn[F(wb[0])] = 0,i = 1 ; i < tbc ; i++)
rn[F(wb[i])] = c0(r,wb[i - 1],wb[i])?p - 1 : p++;
if(p < tbc) dc3(rn,san,tbc,p);
else for(i = 0 ; i < tbc ; i++) san[rn[i]] = i;
for(i = 0 ; i < tbc ; i++) if(san[i] < tb) wb[ta++] = san[i] * 3;
if(n % 3 == 1) wb[ta++] = n - 1;
sort(r,wb,wa,ta,m);
for(i = 0 ; i < tbc ; i++) wv[wb[i] = G(san[i])] = i;
for(i = 0,j = 0,p = 0 ; i < ta && j < tbc ; p++)
sa[p]=c12(wb[j] % 3,r,wa[i],wb[j]) ? wa[i++] : wb[j++];
for(; i < ta ; p++) sa[p] = wa[i++];
for(; j < tbc ; p++) sa[p] = wb[j++];
}
void calheight(int *r, int *sa, int n)
{
int i, j, k = 0;
for(i = 1; i <= n; i++) rank[sa[i]] = i;
for(i = 0; i < n; height[rank[i++]] = k)
for(k ? k-- : 0, j = sa[rank[i] - 1]; r[i + k] == r[j + k]; k++);
return;
}
void calheight2(int *r, int *sa, int n)
{
int i, j, k = 0;
for(i = 1; i <= n; i++) rank2[sa[i]] = i;
for(i = 0; i < n; height2[rank2[i++]] = k)
for(k ? k-- : 0, j = sa[rank2[i] - 1]; r[i + k] == r[j + k]; k++);
return;
}
int dp[maxn][25];
void initRMQ(int n)
{
for(int i = 1; i <= n; i++) dp[i][0] = height[i];
for(int j = 1; (1 << j) <= n; j++)
for(int i = 1; i + (1 << j) - 1 <= n; i++)
dp[i][j] = min(dp[i][j - 1], dp[i + (1 << (j - 1))][j - 1]);
return;
}
int dp2[maxn][25];
void initRMQ2(int n)
{
for(int i = 1; i <= n; i++) dp2[i][0] = height2[i];
for(int j = 1; (1 << j) <= n; j++)
for(int i = 1; i + (1 << j) - 1 <= n; i++)
dp2[i][j] = min(dp2[i][j - 1], dp2[i + (1 << (j - 1))][j - 1]);
return;
}
int askRMQ(int a, int b)
{
int ra = rank[a], rb = rank[b];
if(ra > rb) swap(ra, rb);
int k = 0;
while((1 << (k + 1)) <= rb - ra) k++;
return min(dp[ra + 1][k], dp[rb - (1 << k) + 1][k]);
}
int len;
int askRMQ2(int a, int b)
{
a = len - a - 1;
b = len - b - 1;
int ra = rank2[a], rb = rank2[b];
if(ra > rb) swap(ra, rb);
int k = 0;
while((1 << (k + 1)) <= rb - ra) k++;
return min(dp2[ra + 1][k], dp2[rb - (1 << k) + 1][k]);
}
char in[maxn];
int s[MAXN], lef[maxn], rig[maxn];
ll sumr[maxn];
map<pii,int> lrx;
set<pii> yes;
vector<Node> sx;
ll S[maxn<<2],S2[maxn<<2],add[maxn<<2];
bool clr[maxn<<2];
void PushUp(int rt)
{
S[rt] = S[rt<<1] + S[rt<<1|1];
S2[rt] = S2[rt<<1] + S2[rt<<1|1];
}
void PushDown(int rt,int l,int r)
{
int m = r - l + 1;
if(clr[rt])
{
add[rt<<1] = add[rt<<1|1] = 0;
S[rt<<1] = S[rt<<1|1] = 0;
S2[rt<<1] = S2[rt<<1|1] = 0;
clr[rt<<1] = clr[rt<<1|1] = true;
clr[rt] = false;
}
if(add[rt])
{
add[rt<<1] += add[rt];
add[rt<<1|1] += add[rt];
S[rt<<1] += add[rt] * (m - (m >> 1));
S[rt<<1|1] += add[rt] * (m >> 1);
int mid = l+r>>1;
S2[rt<<1] += (1LL*mid*(mid+1)/2 - 1LL*l*(l-1)/2)*add[rt];
S2[rt<<1|1] += (1LL*r*(r+1)/2 - 1LL*mid*(mid+1)/2)*add[rt];
add[rt] = 0;
}
}
void update(int L,int R,int c,int l,int r,int rt)
{
if (L <= l && r <= R)
{
add[rt] += c;
S[rt] += 1LL * c * (r - l + 1);
S2[rt] += (1LL*r*(r+1)/2 - 1LL*l*(l-1)/2)*c;
return ;
}
PushDown(rt , l, r);
int m = (l + r) >> 1;
if (L <= m) update(L , R , c , lson);
if (m < R) update(L , R , c , rson);
PushUp(rt);
}
ll queryS(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return S[rt];
}
PushDown(rt, l, r);
int m = (l + r) >> 1;
ll ret = 0;
if (L <= m) ret += queryS(L , R , lson);
if (m < R) ret += queryS(L , R , rson);
return ret;
}
ll queryS2(int L,int R,int l,int r,int rt)
{
if (L <= l && r <= R)
{
return S2[rt];
}
PushDown(rt, l, r);
int m = (l + r) >> 1;
ll ret = 0;
if (L <= m) ret += queryS2(L , R , lson);
if (m < R) ret += queryS2(L , R , rson);
return ret;
}
vector<int> vec;
ll Work()
{
ll ret = 0;
S[1] = S2[1] = add[1] = 0;
clr[1] = true;
int vsiz = sx.size();
for(int i=0; i<vsiz; i++)
{
update(sx[i].l+1,sx[i].r+1,1,1,len,1);
}
for(int i=0; i<vsiz; i++)
{
ll X = sx[i].l+2-sx[i].x*sx[i].L;
ll L = sx[i].r-sx[i].l+1;
ll tmp = 0;
if(X > 1) tmp = queryS(1,X-1,1,len,1)*L;
if(L >= 2){
tmp += (L+X-1)*queryS(X,X+L-2,1,len,1);
tmp -= queryS2(X,X+L-2,1,len,1);
}
ret += tmp;
}
return ret;
}
int main()
{
scanf("%s", in);
len = strlen(in);
for(int i = 0; i < len; i++)
s[i] = in[i] - 'a' + 1;
s[len] = 0;
dc3(s, sa, len + 1, 28);
calheight(s, sa, len);
initRMQ(len);
for(int i = 0; i < len; i++)
s[i] = in[len - 1 - i] - 'a' + 1;
s[len] = 0;
dc3(s, sa, len + 1, 28);
calheight2(s, sa, len);
initRMQ2(len);
for(int L = 1; L <= (len >> 1); L++)//枚举循环节长度
{
int n = len;
int blocks = n / L + (n % L != 0);
int now = 1;
while(now < blocks)
{
if(now + 1 < blocks)
{
int len2 = askRMQ2((now + 1)*L - 1, now*L - 1);
int len1 = askRMQ(now*L, (now + 1)*L);
int totlen = L + len2 + len1;
int cnt = totlen / L;
if(!yes.count(mp(now*L-len2,(now+1)*L+len1-1)))
{
yes.insert(mp(now*L-len2,(now+1)*L+len1-1));
for(int i = 2; i <= cnt; i++)
{
if(now*L-len2+i*L-1<=(now+1)*L+len1-1) nodes.pb(Node(now*L-len2+i*L-1,(now+1)*L+len1-1,i,L));
}
}
now = now + (len1 / L) + 1;
}
else
{
if(n % L != 0) break;
int len2 = askRMQ2(len - 1, now*L - 1);
int totlen = (len - now*L) + len2;
int cnt = totlen / L;
if(!yes.count(mp(now*L-len2,len-1)))
{
yes.insert(mp(now*L-len2,len-1));
for(int i = 2; i <= cnt; i++)
{
if(now*L-len2+i*L-1<=len-1)nodes.pb(Node(now*L-len2+i*L-1,len-1,i,L));
}
}
now = blocks;
}
}
}
ntot = nodes.size();
sort(nodes.begin(),nodes.end());
int cur = 0;
ll ans = 0;
for(int i=0; ntot && i<=ntot; i++)
{
if(i==ntot || nodes[i].x != cur)
{
if(cur) ans += Work();
sx.clear();
cur = nodes[i].x;
}
if(i<ntot) sx.pb(nodes[i]);
}
S[1] = S2[1] = add[1] = 0;
clr[1] = true;
for(int i=0; i<ntot; i++)
{
update(nodes[i].l+1,nodes[i].r+1,1,1,len,1);
}
sumr[0] = 1;
for(int i=1; i<len; i++)
{
sumr[i] = i+1-queryS(i+1,i+1,1,len,1)+sumr[i-1];
}
S[1] = S2[1] = add[1] = 0;
clr[1] = true;
for(int i=0; i<ntot; i++)
{
update(nodes[i].l-nodes[i].x*nodes[i].L+2,nodes[i].r-nodes[i].x*nodes[i].L+2,1,1,len,1);
}
for(int i=0; i<len; i++)
{
lef[i] = queryS(i+1,i+1,1,len,1);
}
for(int i=1; i<len; i++)
{
ans += 1LL*sumr[i-1]*(len-i-lef[i]);
}
cout<<ans<<endl;
return 0;
}