题目:https://loj.ac/problem/6031
题解:对k进行分治。
k小时用莫队维护处[a,b]中每个区间贡献次数c[l][r],
再暴力枚举区间统计答案。
k大时询问次数少,所以对于每个询问串预处理匹配,再暴力枚举【a,b】统计答案
相当于在pnt树上倍增跳找合适长度的right集合
注意:跳pnt后匹配长度不一定等于当前节点的val,需要对已匹配长度取min
莫队很容易打错边界,特别是l和r的贡献到底统计没。并且cnt数组要清空
sam写在struct里面,注意last=np不能是nq
数据分治的题分开调,写了2h10分钟我真的好菜,QaQ
#include<bits/stdc++.h>
using namespace std;
#define maxn 200020
typedef long long ll;
int n,m,q,k,sz,cnt[1020][1020],match[maxn],len[maxn];
char ch[maxn];
ll ans[maxn];
struct SAM{
int next[maxn][26],tot,pnt[maxn],val[maxn],last,sz[maxn];
int a[maxn],b[maxn],jump[20][maxn];
void Add(int x){ //插入一个字符
int np = ++tot ,p = last;
val[np] = val[p] + 1 , sz[np] = 1; //只有np节点有sz
while ( !next[p][x] && p ) next[p][x] = np , p = pnt[p];
int q = next[p][x];
if ( !q ) pnt[np] = p , next[p][x] = np;
else if ( val[p] + 1 == val[q] ) pnt[np] = q;
else{
int nq = ++tot;
val[nq] = val[p] + 1;
pnt[nq] = pnt[q];
pnt[np] = pnt[q] = nq;
memcpy(next[nq],next[q],sizeof(next[q]));
while ( next[p][x] == q && p ) next[p][x] = nq , p = pnt[p];
if ( next[p][x] == q ) next[p][x] = nq;
}
last = np;
}
void getsize(){ //桶排序(相同于拓扑序),求right集合大小
for (int i = 1 ; i <= tot ; i++) a[val[i]]++;
for (int i = 1 ; i <= n ; i++) a[i] += a[i - 1];
for (int i = tot ; i >= 1 ; i--) b[a[val[i]]--] = i;
for (int i = tot ; i >= 1 ; i--) sz[pnt[b[i]]] += sz[b[i]];
}
void pre(){ //预处理倍增
for (int i = 1 ; i <= tot ; i++) jump[0][i] = pnt[i];
for (int i = 1 ; i <= 17 ; i++)
for (int j = 1 ; j <= tot ; j++)
jump[i][j] = jump[i - 1][jump[i - 1][j]];
}
void print(){ //打印
for (int i = 1 ; i <= tot ; i++) cout<<i<<" "<<pnt[i]<<" "<<sz[i]<<endl;
}
}sam;
struct node{
int l,r;
}dt[maxn];
struct node2{
int a,b,id;
char ch[1020];
bool operator < (node2 x)const{
return b < x.b;
}
};
vector <node2> Q[120];
//k小时莫队,暴力统计答案,O(nsqrt(n) + nk)
void solve1(){
for (int i = 1 ; i <= q ; i++){
node2 cur; cur.id = i;
scanf("%s",cur.ch + 1);
scanf("%d %d",&cur.a,&cur.b) , cur.a++ , cur.b++;
Q[cur.a / sz].push_back(cur);
}
for (int i = 0 ; i <= n / sz + 1 ; i++){
sort(Q[i].begin(),Q[i].end());
for (register int x = 1 ; x <= k ; x++) for (register int y = x ; y <= k ; y++) cnt[x][y] = 0;
int l = 0 , r = 0;
for (int j = 0 ; j < Q[i].size() ; j++){
while ( r < Q[i][j].b ) ++r , cnt[dt[r].l][dt[r].r]++;
if ( l < Q[i][j].a ) while ( l < Q[i][j].a ) cnt[dt[l].l][dt[l].r]-- , ++l;
else while ( l > Q[i][j].a ) --l , cnt[dt[l].l][dt[l].r]++;
//getans
for (register int x = 1 ; x <= k ; x++){
int cur = 0;
for (register int y = x ; y <= k ; y++){
cur = sam.next[cur][Q[i][j].ch[y] - 'a'];
if ( !cur ) break;
ans[Q[i][j].id] += (ll)sam.sz[cur] * cnt[x][y];
}
}
}
}
for (int i = 1 ; i <= q ; i++) printf("%lld\n",ans[i]);
}
//k较大,对字符串预处理
int getans(int x,int len){
if ( !x || sam.val[x] < len ) return 0;
for (int i = 17 ; i >= 0 ; i--){
if ( sam.val[sam.jump[i][x]] >= len ) x = sam.jump[i][x];
}
return sam.sz[x];
}
void solve2(){
for (int i = 1 ; i <= q ; i++){
int a,b,cur = 0; ll ans = 0;
scanf("%s",ch + 1) , scanf("%d %d",&a,&b) , a++ , b++;
for (int j = 1 ; j <= k ; j++){
int tmp = len[j - 1];
//求每个子串匹配最长长度
while ( cur && !sam.next[cur][ch[j] - 'a'] ) cur = sam.pnt[cur] , tmp = min(tmp,sam.val[cur]);
if ( sam.next[cur][ch[j] - 'a'] ){
cur = sam.next[cur][ch[j] - 'a'];
match[j] = cur , len[j] = tmp + 1;
}
else len[j] = 0;
}
for (register int j = a ; j <= b ; j++){
register int l = dt[j].l , r = dt[j].r;
if ( len[r] < r - l + 1 ) continue;
ans += getans(match[r],r - l + 1);
}
printf("%lld\n",ans);
}
}
void init(){
scanf("%s",ch + 1);
for (int i = 1 ; i <= m ; i++) scanf("%d %d",&dt[i].l,&dt[i].r) , dt[i].l++ , dt[i].r++;
for (int i = 1 ; i <= n ; i++) sam.Add(ch[i] - 'a');
sam.getsize();
sam.pre();
// sam.print();
}
int main(){
// freopen("input.txt","r",stdin);
scanf("%d %d %d %d",&n,&m,&q,&k);
sz = min(1000,(int)sqrt(n) * 3);
init();
if ( k < sz ) solve1();
else solve2();
return 0;
}