题目链接:https://vjudge.net/problem/POJ-3376
很不错的一道题,对我来说挺有难度,写起来很有感觉
先说题意,给出很多字符串,比如有n个,它们两两组合可以形成n*n个字符串,求其中有多少个是回文串
要用到扩展kmp和trie树两个东西
先把所有给出的字符串组成一个Trie树,中间要加一些数据,后面再说
再用所有字符串的转置去匹配
比如说abc和ba这两个串,那么ba的转置ab和abc的前缀是匹配的,剩下c是回文串,就可以断定abc和ba可以组成回文串abcba
还有一种情况,ab和cba,cba的转置abc的前缀和ab是匹配的,剩下c是回文串...
所以在Trie树的节点里要加一个数据,代表此节点后连接的回文串个数,这样匹配到这里结束后就额外加上这个数
或者是匹配还未结束,但你匹配的字符串后边是一个回文串,就加上已匹配的部分串的个数
判断回文就将字符串转置用扩展KMP的extend数组来做
具体看代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
typedef long long ll;
const int N=2*1e6+10; //文本串长度
const int M=2*1e6+10; //模式串长度
char a[N]; //文本串
char b[M]; //模式串
int next0[M],extend[N];
void get_next()
{
int len=strlen(b);
next0[0]=len;
int i=0;
while(i<len-1&&b[i]==b[i+1]) ++i;
next0[1]=i;
i=1;
for(int k=2;k<len;++k)
{
int p=i+next0[i]-1;
int L=next0[k-i];
if(k-1+L>=p)
{
int j=(p-k+1)>0?(p-k+1):0;
while(k+j<len&&b[k+j]==b[j]) ++j;
next0[k]=j;
i=k;
}
else next0[k]=L;
}
}
void get_extend() //extend[i]保存以a[i]为起始的文本串与模式串的最大公共前缀长度
{
int i=0,j,po,alen=strlen(a),blen=strlen(b);
get_next();
while(a[i]==b[i]&&i<alen&&i<blen) i++;
extend[0]=i;
po=0;
for(i=1;i<alen;i++)
{
if(next0[i-po]+i<extend[po]+po) extend[i]=next0[i-po];
else
{
j=extend[po]+po-i;
if(j<0) j=0;
while(i+j<alen&&j<blen&&a[j+i]==b[j]) j++;
extend[i]=j;
po=i;
}
}
}
#define idx(x) x-'a'; //小写字母
const int MAXN=2*1e6+10; //最大节点数
struct Trie
{
int next[26]; //限于小写字母或大写字母
int val; //保存之后的回文串个数
int num; //保存个数
}tree[MAXN];
int nxt,T;
int Add() //分配数组地址
{
memset(&tree[nxt],0,sizeof(Trie));
return nxt++;
}
void Insert(char *s) //插入到字典树
{
int rt=0,len=strlen(s);
for(int i=0;i<len;i++)
{
int c=idx(s[i]);
if(!tree[rt].next[c]) tree[rt].next[c]=Add();
rt=tree[rt].next[c];
if(i<len-1&&extend[i+1]+i+1==len) tree[rt].val++; //后面是回文串
}
tree[rt].num++; //数量
}
ll ans;
int Find(char *s)
{
int rt=0,len=strlen(s);
for(int i=0;i<len;i++)
{
int c=idx(s[i]);
if(!tree[rt].next[c]) return 0; //匹配失败,返回0
rt=tree[rt].next[c];
if(i<len-1&&extend[i+1]+i+1==len) ans+=tree[rt].num; //匹配串后面是回文串,加上个数
}
return tree[rt].val+tree[rt].num; //匹配成功,两者都加
}
void init() //初始化
{
memset(&tree[0],0,sizeof(Trie));
nxt=1;
}
char str[2*MAXN]; //把所有串保存进来,中间加了分隔符,所以要开大一些
int main()
{
init();
int T,n,cnt=0;
scanf("%d",&T);
while(T--)
{
scanf("%d%s",&n,a);
int len=strlen(a);
for(int i=0;i<len;i++) //转置
b[i]=a[len-i-1];
b[len]='\0';
get_extend(); //得到extend数组
Insert(a);
str[cnt++]='|';
for(int i=0;i<len;i++)
str[cnt++]=a[i];
}
ans=0;
int t=0;
for(int i=cnt-1;i>=0;i--) //反向查找,相当于转置了字符串
{
if(str[i]=='|')
{
a[t++]='\0';
int len=strlen(a);
for(int i=0;i<len;i++)
b[i]=a[len-i-1];
get_extend(); //这里再求匹配字串的extend
ans+=Find(a);
t=0;
}
else a[t++]=str[i];
}
printf("%lld\n",ans);
return 0;
}