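// 链接 (Problem link): https://nanti.jisuanke.com/t/41353
// 思路 (Approach): 思路是真难说,但是我觉得我发了代码大家很容易就能理解了,不懂可以评论问一下,会回复的,我去补别的题目去了,这个题目感觉不是很难的sam的运用,只要推出期望式子加上经典的sam套路,赛时读题面扫一眼没读,以为是一个超难的神仙题QAQ。
// (English: the idea is hard to put into words, but it should be clear from the code; ask in the
// comments if not. This is a fairly standard suffix-automaton (SAM) application: derive the
// expectation formula, then apply the classic SAM technique. During the contest I only skimmed
// the statement and mistook it for an extremely hard problem.)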
#include<bits/stdc++.h>
using namespace std;
// ---------------------------------------------------------------------------
// Shared alias, limits and modular-arithmetic helpers.
// ---------------------------------------------------------------------------
using ll = long long;

constexpr ll mod = 998244353;       // prime modulus used by the solution code
constexpr int maxn = 1000000 + 50;  // capacity of the global arrays and SAM state tables
constexpr int maxm = 300000 + 50;   // not referenced elsewhere in this file
// modd has the same value as mod (the helpers below use modd); inf / INF are not referenced here.
constexpr ll modd = 998244353, inf = 0x3f3f3f3f, INF = 0x7fffffff;

// Smaller / larger of two values.
inline ll min(ll a, ll b) { return a < b ? a : b; }
inline ll max(ll a, ll b) { return a > b ? a : b; }

// Greatest common divisor (Euclid), iterative form of the recursive definition.
inline ll gcd(ll a, ll b)
{
    while (b != 0) {
        const ll r = a % b;
        a = b;
        b = r;
    }
    return a;
}

// Extended Euclid: returns d = gcd(a, b) and sets x, y so that a*x + b*y == d.
inline ll exgcd(ll a, ll b, ll &x, ll &y)
{
    if (b == 0) {
        x = 1;
        y = 0;
        return a;
    }
    const ll d = exgcd(b, a % b, y, x);
    y -= a / b * x;
    return d;
}

// a^n mod modd for n >= 0. The base is reduced into [0, modd) first; previously an
// unreduced base overflowed a*a and a negative base gave a negative result.
inline ll qpow(ll a, ll n)
{
    a %= modd;
    if (a < 0) a += modd;
    ll sum = 1;
    for (; n > 0; n >>= 1) {
        if (n & 1) sum = sum * a % modd;
        a = a * a % modd;
    }
    return sum;
}

// a*n mod modd by repeated doubling (n >= 0). Base reduced first so a + a cannot overflow.
inline ll qmul(ll a, ll n)
{
    a %= modd;
    if (a < 0) a += modd;
    ll sum = 0;
    for (; n > 0; n >>= 1) {
        if (n & 1) sum = (sum + a) % modd;
        a = (a + a) % modd;
    }
    return sum;
}

// Modular inverse via Fermat's little theorem (modd is prime); a must not be 0 mod modd.
inline ll inv(ll a) { return qpow(a, modd - 2); }

// (a + b) mod modd; result lies in [0, modd) for non-negative inputs.
inline ll madd(ll a, ll b) { return (a % modd + b % modd) % modd; }

// (a * b) mod modd; result lies in [0, modd) for non-negative inputs.
// Both operands are reduced BEFORE multiplying. The old `a%modd * b%modd` parsed as
// ((a%modd)*b)%modd, which never reduced b and overflowed for large b.
inline ll mmul(ll a, ll b) { return (a % modd) * (b % modd) % modd; }
// ---- Global state shared by fx(), init(), SAM::extend() and main() ----
int t;          // number of test cases still to process
int l;          // number of characters of the initial string s
int k;          // highest index of the coefficients arr[0..k]
int n;          // length bound used by fx() and SAM::extend() (presumably the target string length)
int m;          // number of characters appended after the initial string
ll arr[maxn];   // polynomial coefficients evaluated by fx()
char s[maxn];   // initial string, stored 1-indexed in s[1..l]
ll rec[maxn];   // rec[i] = (fx(1) + ... + fx(i)) mod p; rec[0] is never written and stays 0
// Modular exponentiation: returns a^b mod p in [0, p) for b >= 0 and 1 <= p < ~3.03e9
// (so that (p-1)^2 fits in long long). `long long` is exactly the file's `ll` alias.
long long qk(long long a, long long b, long long p)
{
    // Reduce the base up front. Previously an unreduced |a| >= ~3.04e9 overflowed a*a,
    // and a negative base produced a negative "remainder".
    a %= p;
    if (a < 0) a += p;

    long long ans = 1 % p;  // 1 % p makes p == 1 return 0
    for (; b > 0; b >>= 1)  // b > 0 (not just b != 0): a negative b used to loop forever
    {
        if (b & 1) ans = ans * a % p;
        a = a * a % p;
    }
    return ans;
}
// Returns (n - x + 1) * 26^(-x) * P(x) mod p, where P(x) = arr[0] + arr[1]*x + ... + arr[k]*x^k
// and n, k, arr are the current test case's globals.
// (n - x + 1) / 26^x equals the expected number of occurrences of one fixed length-x string in a
// uniformly random length-n string over 26 letters. This is presumably the "expectation formula"
// from the header note; TODO confirm against the problem statement.
ll fx(ll x)
{
    // Weight factor (n - x + 1) / 26^x, with the division done via the modular inverse.
    const ll weight = mmul(n - x + 1, inv(qk(26, x, mod)));

    // Evaluate P(x) mod p. Every coefficient, including arr[0] (which used to be added
    // unreduced), is normalised into [0, mod) so mmul never sees an oversized operand.
    const ll xm = x % mod;
    ll poly = (arr[0] % mod + mod) % mod;
    ll power = 1;
    for (int i = 1; i <= k; i++)
    {
        power = power * xm % mod;
        const ll coef = (arr[i] % mod + mod) % mod;
        poly = (poly + coef * power) % mod;
    }
    return mmul(weight, poly);
}
void init(int n)
{
for(int i=1;i<=n;i++)
{
rec[i]=(rec[i-1]+fx(i))%mod;
}
}
// Suffix automaton over letters 0..29. State 0 is a "no state" sentinel: real states start at 1,
// a zero transition means "absent", and fa[root] == 0.
struct SAM
{
    int next[maxn][30], fa[maxn], len[maxn];  // transitions, suffix links, longest-string lengths
    int root, tot, last;                      // root state, next free index, state of whole string

    // Number of distinct substrings represented by state x.
    int val(int x)
    {
        const int parentLen = len[fa[x]];
        return len[x] - parentLen;
    }

    // Allocates a fresh state with no transitions, no suffix link and length l.
    int newnode(int l)
    {
        const int id = tot;
        ++tot;
        fa[id] = 0;
        memset(next[id], 0, sizeof(next[id]));
        len[id] = l;
        return id;
    }

    // Resets the automaton to a single root state (index 1).
    void init()
    {
        tot = 1;
        root = newnode(0);
        last = root;
    }

    // Appends letter x and returns, mod p, the sum of rec-weights (fx values) over the lengths L
    // of the substrings that appear for the first time, keeping only L <= n.
    ll extend(int x)
    {
        const int cur = newnode(len[last] + 1);
        int p = last;
        for (; p != 0 && next[p][x] == 0; p = fa[p])
            next[p][x] = cur;

        if (p == 0)
        {
            fa[cur] = root;
        }
        else
        {
            const int q = next[p][x];
            if (len[q] == len[p] + 1)
            {
                fa[cur] = q;
            }
            else
            {
                // Split q: the clone takes the shorter strings and q's outgoing transitions.
                const int clone = newnode(len[p] + 1);
                memcpy(next[clone], next[q], sizeof(next[q]));
                fa[clone] = fa[q];
                fa[q] = clone;
                fa[cur] = clone;
                for (; p != 0 && next[p][x] == q; p = fa[p])
                    next[p][x] = clone;
            }
        }
        last = cur;

        // New substrings have lengths (len[fa[cur]], len[cur]]; only lengths up to n contribute.
        const int linkLen = len[fa[cur]];
        if (linkLen >= n) return 0;
        const ll upper = min(n, len[cur]);
        return (rec[upper] - rec[linkLen] + mod) % mod;
    }
}sam;
int main()
{
scanf("%d",&t);
while(t--)
{
scanf("%d%d%d%d",&l,&k,&n,&m);
scanf("%s",s+1);
for(int i=0;i<=k;i++) scanf("%lld",&arr[i]);
int mi=min(l+m,n);
init(mi);
sam.init();
ll fin=0;
for(int i=1;i<=l;i++)
{
ll zz=sam.extend(s[i]-'a');
fin=(fin+zz)%mod;
}
printf("%lld\n",fin);
for(int i=1;i<=m;i++)
{
char ch[5];
scanf("%s",ch);
ll zz=sam.extend(ch[0]-'a');
fin=(fin+zz)%mod;
printf("%lld\n",fin);
}
}
return 0;
}