(此题为付费比赛题目)
L.在这冷漠的世界里光光哭哭
小红、对立和光三人在玩一个游戏。
首先她们拿到了一个仅由小写字母构成的字符串。小红指定一个区间 [l,r],对立说出一个长度为 3 的字符串 s ,然后光需要回答出来,在这个区间中,一共有多少个 s 的子序列?
光感觉很蓝瘦,于是想来求助聪明的你,让你帮她回答问题。
输入描述:
第一行输入两个正整数 n 和 q ,分别代表字符串长度和询问次数。
第二行输入一个长度为 n 的、仅由小写字母构成的字符串t。
接下来的 q 行,每行输入两个正整数 l 和 r,以及一个长度为3的仅由小写字母构成的字符串 s。
数据范围:
1 ≤ n ≤ 80000 1 \leq n \leq 80000 1≤n≤80000
1 ≤ q ≤ 500000 1 \leq q \leq 500000 1≤q≤500000
输出描述:
对于每次询问,你需要帮光回答在该区间内包含 s 子序列的数量。
input
10 6
aabbbaacac
1 10 aba
1 10 abd
2 8 abc
5 10 bac
3 4 abc
1 10 aaa
output
18
0
3
5
0
10
很妙的一道容斥计数,简单来讲就是
a
n
s
=
(
三
个
字
母
都
在
【
1
,
r
】
区
间
的
子
序
列
)
−
(
三
个
字
母
都
在
【
1
,
l
−
1
】
的
子
序
列
)
−
(
前
两
个
字
母
在
【
1
,
l
−
1
】
最
后
一
个
字
母
在
【
l
,
r
】
)
−
(
第
一
个
字
母
在
【
1
,
l
−
1
】
,
后
两
个
字
母
在
【
l
,
r
】
)
ans=(三个字母都在【1,r】区间的子序列)- (三个字母都在【1,l-1】的子序列)- (前两个字母在【1,l - 1】最后一个字母在【l,r】) - (第一个字母在【1,l - 1】,后两个字母在【l,r】)
ans=(三个字母都在【1,r】区间的子序列)−(三个字母都在【1,l−1】的子序列)−(前两个字母在【1,l−1】最后一个字母在【l,r】)−(第一个字母在【1,l−1】,后两个字母在【l,r】)
考 虑 维 护 s u m 1 [ j ] [ j ] 表 示 i 前 面 有 多 少 字 符 j , 容 易 得 到 以 下 递 推 考虑维护sum1[ j ][ j ] 表示 i 前面有多少字符 j ,容易得到以下递推 考虑维护sum1[j][j]表示i前面有多少字符j,容易得到以下递推
s u m 1 [ j ] [ i ] = s u m 1 [ j ] [ i − 1 ] + ( s [ i ] = = j + ′ a ′ ) ; sum1[j][i]=sum1[j][i-1]+(s[i]==j+'a'); sum1[j][i]=sum1[j][i−1]+(s[i]==j+′a′);
先睡了后半部分有时间再更(摆了
AC代码
#include <bits/stdc++.h>
#define mod 1000000007
using namespace std;
int n,m,k,w,cnt,sum,l,r,q;
long long ans;
char s[80005];
long long sum1[30][80005],sum2[800][80005],sum3[800][80005];
vector<int> p[1005];
void solve(){
scanf("%d %d",&n,&q);
scanf("%s",s+1);
for(int i=0;i<26;i++){p[i].push_back(0);}
for(int i=1;i<=n;i++){
int u=p[s[i]-'a'].back();
p[s[i]-'a'].push_back(i);
for(int j=0;j<26;j++){
sum1[j][i]=sum1[j][i-1]+(s[i]==j+'a');
for(int k=0;k<26;k++){
sum2[j*26+k][i]=sum2[j*26+k][i-1]+(k==s[i]-'a')*(sum1[j][i-1]);
sum3[j*26+k][i]=sum3[j*26+k][u]+sum2[j*26+k][i-1];
}
}
}
while(q--){
scanf("%d %d %s",&l,&r,s);
m=s[2]-'a';
int m1=(sum1[m][r]-sum1[m][l-1]);
k=(s[0]-'a')*26+s[1]-'a';
w=(s[1]-'a')*26+s[2]-'a';
int u=(upper_bound(p[m].begin(),p[m].end(),r)-p[m].begin()-1),v=(lower_bound(p[m].begin(),p[m].end(),l)-p[m].begin()-1);
ans=sum3[k][p[m][u]]-sum3[k][p[m][v]];
ans-=m1*sum2[k][l-1];
ans-=sum1[s[0]-'a'][l-1]*(sum2[w][r]-sum2[w][l-1]-(m1*sum1[s[1]-'a'][l-1]));
printf("%lld\n",ans);
}
}
signed main(){
int t=1;
cout<<fixed<<setprecision(12);
while (t--)
solve();
}