给定 n , q ( 2 e 5 ) n,q(2e5) n,q(2e5),和一个长为 n n n 的字符串 s s s。 q q q次询问,每次给定两个数组 a i , b i a_i,b_i ai,bi。
对于每次询问,输出 Σ Σ l c p ( a i , b i ) \Sigma\Sigma lcp(a_i,b_i) ΣΣlcp(ai,bi).
保证所有询问的 a a a和 b b b数组长度之和不超过 2 e 5 2e5 2e5.
先写一个暴力。
while(q--)
{
int la = read(), lb = read();
for(int i=1; i<=la; ++i) A[i] = read();//rk[read()-1];
for(int i=1; i<=lb; ++i) B[i] = read();//rk[read()-1];
ll ans = 0;
for(int i=1; i<=la; ++i)
for(int j=1; j<=lb; ++j)
ans += SA::lcp(A[i]-1, B[j]-1);
cout << ans << "\n";
}
做后缀数组,将给定的 a i , b i a_i,b_i ai,bi映射到 r a n k [ a i ] , r a n k [ b i ] rank[a_i],rank[b_i] rank[ai],rank[bi],问题就变成了:
每次给定两个数组 a i a_i ai, b i b_i bi,求 Σ Σ m i n j = a i + 1 b i { h e i g h t [ j ] } \Sigma\Sigma min_{j=a_i+1}^{b_i}\{height[j]\} ΣΣminj=ai+1bi{height[j]}.
单调栈
将单次询问的所有 a i , b i a_i,b_i ai,bi排好序,依次扫描,用两个单调栈 s t a c k _ a stack\_a stack_a, s t a c k _ b stack\_b stack_b维护历史最小值。
单调栈里的元素是 p a i r < v a l , c n t > pair<val, cnt> pair<val,cnt>,其中 v a l val val表示一个历史最小值, c n t cnt cnt表示这个值的出现次数, v a l val val从栈底到栈顶依次增大。
每当新扫描到一个值 k e y key key的时候,如果 k e y > t o p [ v a l ] key>top[val] key>top[val],就将 < k e y , 1 > <key,1> <key,1>压入栈顶;否则就将栈顶 < v a l , c n t > <val,cnt> <val,cnt>弹出,然后待插入元素变为 < k e y , c n t + 1 > <key, cnt+1> <key,cnt+1>,重复这个过程。
此外,需要动态维护这个单调栈的和。
操作序列
为了达到单组询问 O ( ∣ a ∣ + ∣ b ∣ ) O(|a|+|b|) O(∣a∣+∣b∣)而不是 O ( n ) O(n) O(n)的复杂度,需要对操作进行精细讨论。
用一个有序的vector<pair<int,int>>
来维护所有的操作,first是进行这个操作的时刻,second是进行这个操作的种类。
时刻注意,求 l c p lcp lcp时对 h e i g h t height height取 m i n min min的范围是 [ l + 1 , r ] [l+1,r] [l+1,r],而不是 [ l , r ] [l,r] [l,r].
也就是一个位置 a i a_i ai或 b i b_i bi应该分两个操作,查找和插入,查找在前插入在后且两个操作的时刻恰好差一。
操作种类划分为:0.a插入,1.b插入,2.a找值,3.b找值
插入的时刻是 本 身 位 置 + 1 本身位置+1 本身位置+1,插入的值是 h e i g h t [ 本 身 位 置 + 1 ] height[本身位置+1] height[本身位置+1],插到自己的单调栈里面,同时更新一下对面的单调栈。
找值的时刻是 本 身 位 置 本身位置 本身位置,找值的值是 m i n _ h e i g h t [ 上 一 次 操 作 的 时 刻 , 本 身 位 置 ] min\_height[上一次操作的时刻,本身位置] min_height[上一次操作的时刻,本身位置],同时更新一下对面。
自身的LCP
操作分割后,上述方法无法统计自身和自身的LCP。
解决的办法是,读入数据后,双指针扫一遍加上就可以。
总的复杂度 O ( n l o g n ) O(nlogn) O(nlogn),题解的做法是用map,复杂度 O ( l o g 2 n ) O(log^2n) O(log2n)。但是我的做法常数比较大,跑了快400ms。
总结:
- 这个神奇的单调栈是之前做洛谷一道题的时候想到的,一直没来得及写,后来学会SAM后把那道题用SAM秒了hhh。
- 一直感觉这个单调栈很难写,但真正写起来就十几行,很统一。
真正麻烦的地方在于操作序列以及相同位置的lcp,不过想清楚后也还好。struct MonoStack { ll sum = 0, r = 0; //栈和,栈指针 pair<ll,ll> save[M]; //val,cnt, val单调递增 void init(){sum = r = 0;} ll solve(ll key, ll add) //可以入栈,但没必要 { while(r && key<=save[r-1].first) { add += save[--r].second; sum -= save[r].first * save[r].second; } if(add) { sum += key * add; save[r++] = {key, add}; } return sum; } }st[2]; //0维护a,1维护b
- 试图使用SAM解这个题,但是没想到怎么用SAM求lcp,感觉这是我的一个知识盲区,之后关注一下。
- SA和SAM的板子该整理一下了,被费神嘲讽行数多,有一说一确实好多。
- 这个题收录一下,非常经典,考虑打印下来作板子用。
- 这场cf是我去年打的…当时只做了三个题,这道题是第7题orz.
完整代码
/* LittleFall : Hello! */
#include <bits/stdc++.h>
using namespace std; using ll = long long; inline int read();
const int M = 200016, MOD = 1000000007;
namespace SA
{
void st_init(int *arr, int n);
/* 后缀数组 */
int sa[M], rk[M], height[M]; //后缀三数组,sa和rk下标从0开始,height下标从1开始
int t1[M], t2[M], c[M]; // 用于基数排序的三个辅助数组
void build(char *str, int n, int m) // 构造后缀三数组,字符串下标从0开始,n表示长度,m表示字符集大小
{
str[n] = 0;
n++;
int i, j, p, *x = t1, *y = t2;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[i]=str[i]]++;
for(i = 1; i < m; i++) c[i] += c[i-1];
for(i = n-1; i >= 0; i--) sa[--c[x[i]]] = i;
for(j = 1; j <= n; j<<=1)
{
p = 0;
for(i = n-j; i < n; i++) y[p++] = i;
for(i = 0; i < n; i++) if(sa[i] >= j) y[p++] = sa[i]-j;
for(i = 0; i < m; i++) c[i] = 0;
for(i = 0; i < n; i++) c[x[y[i]]]++;
for(i = 1; i < m; i++) c[i] += c[i-1];
for(i = n-1; i >= 0; i--) sa[--c[x[y[i]]]] = y[i];
swap(x, y);
p = 1; x[sa[0]] = 0;
for(i = 1; i < n; i++)
x[sa[i]] = (y[sa[i-1]]==y[sa[i]]&&y[sa[i-1]+j]==y[sa[i]+j]) ? p-1 : p++;
if(p >= n) break;
m = p;
}
n--;
for(int i = 0; i <= n; i++) rk[sa[i]] = i;
for(int i=0, j=0, k=0; i < n; i++)
{
if(k) k--;
j = sa[rk[i]-1];
while(str[i+k]==str[j+k]) k++;
height[rk[i]] = k;
}
st_init(height, n);
}
/* ST表 */
int lg[M], _n;
int table[20][M];
void st_init(int *arr, int n)
{
_n = n;
if(!lg[0])
{
lg[0]=-1;
for(int i=1;i<M;i++)
lg[i]=lg[i/2]+1;
}
for(int i=1; i<=n; ++i)
table[0][i] = arr[i];
for(int i=1; i<=lg[n]; ++i)
for(int j=1; j<=n; ++j)
if(j+(1<<i)-1 <= n)
table[i][j] = min(table[i-1][j], table[i-1][j+(1<<(i-1))]);
}
int mih(int l, int r)
{
int t = lg[r-l+1];
return min(table[t][l], table[t][r-(1<<t)+1]);
}
};
struct MonoStack
{
ll sum = 0, r = 0; //栈和,栈指针
pair<ll,ll> save[M]; //val,cnt, val单调递增
void init(){sum = r = 0;}
ll solve(ll key, ll add) //可以入栈,但没必要
{
while(r && key<=save[r-1].first)
{
add += save[--r].second;
sum -= save[r].first * save[r].second;
}
if(add)
{
sum += key * add;
save[r++] = {key, add};
}
return sum;
}
}st[2]; //0维护a,1维护b
char str[M];
int A[M], B[M];
int main(void)
{
#ifdef _LITTLEFALL_
freopen("in.txt","r",stdin);
#endif
int n = read(), q = read();
scanf("%s",str);
SA::build(str, n, 128);
using SA::height; using SA::rk; using SA::mih;
// for(int i=1; i<=n; ++i)
// printf("%d ",height[i] );
// printf("\n");
while(q--)
{
st[0].init(), st[1].init();
int la = read(), lb = read();
for(int i=1; i<=la; ++i) A[i] = read()-1;// printf("rka=%d ",rk[A[i]] );
for(int i=1; i<=lb; ++i) B[i] = read()-1;// printf("rkb=%d ",rk[B[i]] );
//printf("\n");
//首先把所有的相同位置都算一遍
ll ans = 0;
for(int i=1, j=1; i<=la && j<=lb; )
{
if(A[i]==B[j]) ans += n-A[i], ++i, ++j;
else if(A[i]<B[j]) ++i;
else ++j;
}
vector<pair<int,int>> ops; //时刻,操作种类
//操作种类有:0.a插入,1.b插入,2.a找值,3.b找值,插入应优先于找值
//插入的时刻是本身位置+1,插入的值是height[本身位置+1],插到自己里面,同时更新一下对面
//找值的时刻是本身位置,找值的值是min_height[上一次操作的时刻,本身位置],同时更新一下对面
for(int i=1; i<=la; ++i) ops.emplace_back(rk[A[i]]+1, 0), ops.emplace_back(rk[A[i]], 2);
for(int i=1; i<=lb; ++i) ops.emplace_back(rk[B[i]]+1, 1), ops.emplace_back(rk[B[i]], 3);
sort(ops.begin(), ops.end());
for(int i=0, lst=ops[0].first; i<(int)ops.size(); ++i)
{
int ti = ops[i].first, op = ops[i].second;
if(op<2) //插入
{
st[op].solve(height[ti], 1);
st[op^1].solve(height[ti], 0);
}
else //查值
{
ans += st[3-op].solve(mih(lst,ti), 0);
st[op-2].solve(mih(lst,ti), 0);
}
//printf("ops: ti=%d op=%d ans=%I64d\n",ti,op,ans );
lst = ti;
}
cout << ans << "\n";
}
return 0;
}
inline int read(){
int x=0,f=1;char ch=getchar();
while(ch<'0'||ch>'9') {if(ch=='-')f=-1;ch=getchar();}
while(ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}