题目链接
貌似有分治+DP的做法,但是我太弱了,只会写写可持久化平衡树的模板。
这道题如果想到了这种做法,其实就是几个可持久化平衡树的基本操作。
我们设
l
e
n
len
len为当前字符串的长度,我们会发现,我们进行
k
k
k次循环移位,其实就是将原串从右往左
k
m
o
d
l
e
n
k\space mod\space len
k mod len个位置分裂,然后合并到最左边,再与原串合并。
于是就是很明显的可持久化平衡树了。
查询的话,每个节点存储一下当前子树内所表达的串所含的每个字符的数量,然后就不难了,具体看看代码吧。
时间复杂度的话,建树
26
n
26n
26n。我们会发现我们最多只需要进行60次操作,这个串的长度就会乘2,所以最多只需要
l
o
g
2
r
log_2\space r
log2 r次操作就可以完成,更新节点信息需要更新26个字符的数量。每次询问的复杂度是
l
o
g
2
r
log_2\space r
log2 r,所以总时间复杂度是
O
(
26
n
+
26
l
o
g
2
2
r
+
q
l
o
g
2
r
)
O(26n+26log_2^2r+qlog_2\space r)
O(26n+26log22r+qlog2 r),顺利通过本题。
如果有误请在评论区说一声哦!
代码:
#include<cstdio>
#include<cctype>
#include<cstring>
#include<cstdlib>
typedef long long ll;
int n,m,len,cnt,ban,rt[70],ch[200010][2],val[200010];
ll siz[200010][27],k[200010];
char s[100010];
const ll maxn=1e18;
void maintain(int o){
for(int i=0;i<=26;i++)
siz[o][i]=siz[ch[o][0]][i]+siz[ch[o][1]][i]+(val[o]==i||i==26);
return;
}
void build(int &o,int l,int r){
if(l>r)
return;
o=++cnt;
int mid=(l+r)>>1;
val[o]=s[mid]-'a';
build(ch[o][0],l,mid-1);
build(ch[o][1],mid+1,r);
maintain(o);
return;
}
int cpy(int o){
val[++cnt]=val[o];
memcpy(siz[cnt],siz[o],sizeof(siz[o]));
ch[cnt][0]=ch[o][0];
ch[cnt][1]=ch[o][1];
return cnt;
}
void split(int o,ll k,int &a,int &b){
if(!o){
a=b=0;
return;
}
if(!k){
a=o,b=0;
return;
}
if(k==siz[o][26]){
a=0,b=o;
return;
}
int p=cpy(o);
if(k<=siz[ch[o][1]][26]){
a=p;
split(ch[p][1],k,ch[a][1],b);
}
else{
b=p;
split(ch[p][0],k-siz[ch[o][1]][26]-1,a,ch[b][0]);
}
maintain(p);
return;
}
int merge(int a,int b){
if(!a||!b)
return a|b;
int p;
if((((ll)rand()<<45ll)|((ll)rand()<<30ll)|((ll)rand()<<15ll)|(ll)rand())%(siz[a][26]+siz[b][26])<=siz[a][26]){
p=cpy(a);
ch[p][1]=merge(ch[p][1],b);
}
else{
p=cpy(b);
ch[p][0]=merge(a,ch[p][0]);
}
maintain(p);
return p;
}
ll query(int o,ll k,int c){
if(!o||!k)
return 0;
if(k==siz[o][26])
return siz[o][c];
if(k<=siz[ch[o][0]][26])
return query(ch[o][0],k,c);
return query(ch[o][1],k-siz[ch[o][0]][26]-1,c)+siz[ch[o][0]][c]+(val[o]==c);
}
ll rd(){
ll x=0;
char c;
do c=getchar();
while(!isdigit(c));
do{
x=(x<<1ll)+(x<<3ll)+(c^48);
c=getchar();
}while(isdigit(c));
return x;
}
int main(){
scanf("%s",s+1);
len=strlen(s+1);
n=rd(),m=rd();
for(int i=1;i<=n;i++)
k[i]=rd();
build(rt[0],1,len);
for(ban=1;ban<=n&&siz[rt[ban-1]][26]<maxn;ban++){
k[ban]%=siz[rt[ban-1]][26];
int a,b;
split(rt[ban-1],k[ban],a,b);
rt[ban]=merge(rt[ban-1],merge(b,a));
}
ban--;
while(m--){
ll l=rd(),r=rd();
char c;
do c=getchar();
while(!islower(c));
printf("%I64d\n",query(rt[ban],r,c-'a')-query(rt[ban],l-1,c-'a'));
}
return 0;
}