#include<bits/stdc++.h>
using namespace std;
// Shorthand for the 64-bit signed integer type used for counters and the answer.
using ll = long long;
// Modulus applied to the value printed by main().
constexpr ll mod = 998244353;
// Global (zero-initialised) buffer that main() reads the input string into.
char s[1000100];
int main(){
ll n;
cin>>n;
cin>>s;
ll l=1,r=1;
for(int i=1;i<n;i++){
if(s[i]==s[i-1]){
l++;
}
else break;
}
for(int j=n-1;j>=1;j--){
if(s[j]==s[j-1]){
r++;
}
else break;
}
// cout<<l<<r<<endl;
if(s[0]!=s[n-1]){
cout<<(1+l+r)%mod<<endl;
}
else{
cout<<((l+r)%mod+l*r%mod+1)%mod<<endl;
}
}