2018 ICPC 南京 全文见:https://blog.csdn.net/qq_43461168/article/details/112796538
M. Mediocre String Problem
题意:给定两个串s和t。三元组(i,j,k)表示 s[i-j] + t[0-k],有多少个三元组使得结果串是回文串。简单来说就是在s中找一个子串 ,与t的前缀连起来变成回文串。求方案数。
!思路参考:https://www.cnblogs.com/luowentao/p/10332309.html
思路1:个人觉得思路2更好理解得多。可以先看思路2。把s[i-j] + t[0-k] 拆成三部分看。 s1 + s2 + t1。其中s2不为空。因为题目要求 s1+s2 > t1。s1 = reverse(t1)。那么就要求s2为回文串。也就是要找出 s串的所有回文。然后枚举s1的左端点。回文串可以用Manacher求出来(也还有很多别的方法,回文自动机,字符串哈希)。然后要找s1 == reverse(t1)这个东西。 因为t1就是t的前缀,而s1可以看成是s串的某个后缀的前缀。也就是和t求LCP。如果把s串倒过来。 然后再和t跑一遍EXKMP。刚好就可以求出所有的LCP了。然后就是组合回文串和LCP。遍历一遍原串。以每一个位置i,当成是(i,j,k)里面的j。然后去枚举k。其实枚举k和枚举i是一样的。只需要枚举一个。那其实也就是枚举LCP。如果线性的去枚举。那肯定超时了。因为刚刚已经计算过所有位置的LCP。那么对LCP求一个前缀和。然后枚举的时候,只需要求 i 点 回文半径+1 内的所有LCP之和,就相当于枚举了所有的s2 与 s1。就可以求出答案了。难点是下标的计算。
AC代码1:
#include <iostream>
#include <bits/stdc++.h>
#include <unordered_map>
#define int long long
#define mk make_pair
#define gcd __gcd
using namespace std;
const double eps = 1e-10;
const int mod = 998244353;
const int N = 5e6+7;
int n,m,k,t = 1,cas = 1;
int a[N],b[N],c[N];
const int maxn=3e6+9; //字符串长度最大值
int nex[maxn],ex[maxn]; //ex数组即为extend数组
//预处理计算nex数组
void GETNEXT(char *str)
{
int i=0,j,po,len=strlen(str);
nex[0]=len;//初始化nex[0]
while(str[i]==str[i+1]&&i+1<len)//计算nex[1]
i++;
nex[1]=i;
po=1;//初始化po的位置
for(i=2;i<len;i++)
{
if(nex[i-po]+i<nex[po]+po)//第一种情况,可以直接得到nex[i]的值
nex[i]=nex[i-po];
else//第二种情况,要继续匹配才能得到nex[i]的值
{
j=nex[po]+po-i;
if(j<0)j=0;//如果i>po+nex[po],则要从头开始匹配
while(i+j<len&&str[j]==str[j+i])//计算nex[i]
j++;
nex[i]=j;
po=i;//更新po的位置
}
}
}
//计算extend数组
void EXKMP(char *s1,char *s2)
{
int i=0,j,po,len=strlen(s1),l2=strlen(s2);
GETNEXT(s2);//计算子串的nex数组
while(s1[i]==s2[i]&&i<l2&&i<len)//计算ex[0]
i++;
ex[0]=i;
po=0;//初始化po的位置
for(i=1;i<len;i++)
{
if(nex[i-po]+i<ex[po]+po)//第一种情况,直接可以得到ex[i]的值
ex[i]=nex[i-po];
else//第二种情况,要继续匹配才能得到ex[i]的值
{
j=ex[po]+po-i;
if(j<0)j=0;//如果i>ex[po]+po则要从头开始匹配
while(i+j<len&&j<l2&&s1[j+i]==s2[j])//计算ex[i]
j++;
ex[i]=j;
po=i;//更新po的位置
}
}
}
char Ma[maxn*2]; // #号填充原串
int Mp[maxn*2]; // 回文半径
void Manacher(char s[],int len){
int l=0;
Ma[l++]='$';
Ma[l++]='#';
for(int i=0;i<len;i++){
Ma[l++]=s[i];
Ma[l++]='#';
}
Ma[l]=0;
int mx=0,id=0; // mx 最右回文右边界
for(int i=0;i<l;i++){
Mp[i]=mx>i?min(Mp[2*id-i],mx-i):1;
while(Ma[i+Mp[i]]==Ma[i-Mp[i]])Mp[i]++;
if(i+Mp[i]>mx){
mx=i+Mp[i];
id=i;
}
}
}
int sum[N];
int cal(int l,int r){
if(l > r) return 0;
if(l <= 0) return sum[r];
return sum[r]-sum[l-1];
}
char s1[N],s2[N];
signed main(){
cin>>s1>>s2;
int len1 = strlen(s1);
Manacher(s1,len1);
reverse(s1,s1+len1);
EXKMP(s1,s2);
reverse(ex,ex+len1);
sum[0] = ex[0];
for(int i = 1; i < len1 ; i ++) sum[i] = sum[i-1]+ex[i];
int res = 0;
for(int i = 2 ; i < 2*len1 + 3 ; i ++){
int now = Mp[i]-1; // 回文半径
if(now <= 0) continue;
if(now%2 == 1){ // 奇数长度
int center = (i-2)/2; // 计算回文中心
int r = center-1;
int l = center-Mp[i]/2;
res += cal(l,r);
}else{
int center = (i-2-1)/2;
int r = center-1;
int l = center-now/2;
res += cal(l,r);
}
}
cout<<res<<endl;
return 0;
}
思路2:第一步同思路1,把s[i-j]分成s1 和 s2 去枚举。不过这次是。固定 s1 的右端点。也是就说 s[i-j] 分成 s[i-p] + s[p+1-j] ,枚举这个p。p固定之后。i,j怎么枚举呢。i其实就是LCP的最大长度。也就是 EX[p]。因为每个长度都可以贡献一次嘛。而j就简单了。 j的个数 其实就是 以p+1为起点的回文串的个数。顺着求不是很好求。 但是如果把他倒过来看。也就是把s倒过来算。 那就是以p+1为终点的回文串的个数。也就是回文后缀的个数。这个就是PAM回文树求的东西。因为fail指针就是一个回文后缀。那么fail指针的高度。就是回文后缀的个数。也就是说在插入过程中。记录这个 height[p+1] 就是 以p+1为起点的回文串的个数。有点绕。 总之就是 把s串倒过来。 然后依次插入 回文树。 模板的num[i] 记录的就是回文后缀的个数。求完之后。再把他 num数组倒回来。就行了。然后求解答案的时候。枚举到 p,贡献就是 ex[p]*num[p+1]。
AC代码2:
#include <iostream>
#include <bits/stdc++.h>
#include <unordered_map>
#define int long long
#define mk make_pair
#define gcd __gcd
using namespace std;
const double eps = 1e-10;
const int mod = 998244353;
const int N = 2e6+7;
int n,m,k,t = 1,cas = 1;
int a[N],b[N],c[N];
int mark[N];
char s[N];
struct PAM{
/**
len[u] : u 节点代表回文串的长度。
fa[u] : u 节点代表回文串的最长回文后缀代表的节点。
tran[u][c] : 转移函数,表示在 u 代表的回文串的两端加上字符 c 之后的回文串。
num[u] : 代表 u 节点代表回文串的回文后缀个数。
L[i] : 代表原字符串以 i 结尾的回文后缀长度。
size[u] : u 点代表的回文串的数量。
**/
int len[N],fa[N],size[N],num[N],tot,last,trans[N][27],L[N];
int cnt[N];
void init(){ // 初始化
len[0]=0;fa[0]=1;len[1]=-1;fa[1]=0;
tot=1;last=0;
memset(trans[1],0,sizeof(trans[1]));
memset(trans[0],0,sizeof(trans[0]));
}
int new_node(int x){ // 建立新节点
int now=++tot;
memset(trans[tot],0,sizeof(trans[tot]));
len[now]=x;
return now;
}
int ins(int c,int n){ // 增量法构造
int u=last;
while(s[n-len[u]-1]!=s[n])u=fa[u];
if(trans[u][c]==0){
int now=new_node(len[u]+2);
int v=fa[u];
while(s[n-len[v]-1]!=s[n])v=fa[v];
fa[now]=trans[v][c];
trans[u][c]=now;
num[now]=num[fa[now]]+1;
}
last=trans[u][c];size[last]++;
L[n]=len[last];
cnt[n] = num[last];
return num[last];
}
void build(char *s){
int len = strlen(s);
for(int i = 0 ; i < len ; i ++){
ins(s[i]-'a',i);
}
}
}pam;
const int maxn=3e6+9; //字符串长度最大值
int nex[maxn],ex[maxn]; //ex数组即为extend数组
//预处理计算nex数组
void GETNEXT(char *str)
{
int i=0,j,po,len=strlen(str);
nex[0]=len;//初始化nex[0]
while(str[i]==str[i+1]&&i+1<len)//计算nex[1]
i++;
nex[1]=i;
po=1;//初始化po的位置
for(i=2;i<len;i++)
{
if(nex[i-po]+i<nex[po]+po)//第一种情况,可以直接得到nex[i]的值
nex[i]=nex[i-po];
else//第二种情况,要继续匹配才能得到nex[i]的值
{
j=nex[po]+po-i;
if(j<0)j=0;//如果i>po+nex[po],则要从头开始匹配
while(i+j<len&&str[j]==str[j+i])//计算nex[i]
j++;
nex[i]=j;
po=i;//更新po的位置
}
}
}
//计算extend数组
void EXKMP(char *s1,char *s2)
{
int i=0,j,po,len=strlen(s1),l2=strlen(s2);
GETNEXT(s2);//计算子串的nex数组
while(s1[i]==s2[i]&&i<l2&&i<len)//计算ex[0]
i++;
ex[0]=i;
po=0;//初始化po的位置
for(i=1;i<len;i++)
{
if(nex[i-po]+i<ex[po]+po)//第一种情况,直接可以得到ex[i]的值
ex[i]=nex[i-po];
else//第二种情况,要继续匹配才能得到ex[i]的值
{
j=ex[po]+po-i;
if(j<0)j=0;//如果i>ex[po]+po则要从头开始匹配
while(i+j<len&&j<l2&&s1[j+i]==s2[j])//计算ex[i]
j++;
ex[i]=j;
po=i;//更新po的位置
}
}
}
char s1[N],s2[N];
int cnt[N];
signed main(){
cin>>s1>>s2;
int len1 = strlen(s1);
int len2 = strlen(s2);
strcpy(s,s1);
reverse(s,s+len1);
pam.init();
pam.build(s);
for(int i = 0 ; i < len1 ; i ++){
cnt[i] = pam.cnt[len1-1-i];
}
reverse(s1,s1+len1);
EXKMP(s1,s2);
reverse(ex,ex+len1);
int res = 0;
for(int i = 0 ; i < len1-1 ; i ++){
res += ex[i]*cnt[i+1];
}
cout<<res<<endl;
return 0;
}