3415. Palindrome
单点时限: 1.0 sec
内存限制: 512 MB
Alice like strings, especially long strings. For each string, she has a special evaluation system to judge how elegant the string is. She defines that a string S[1…3n−2] is one-and-half palindromic if and only if it satisfies S[i]=S[2n−i]=S[2n+i−2] (1≤i≤n). For example, abcbabc
is one-and-half palindromic string, and abccbaabc
is not. Now, Alice has generated some long strings. She ask for your help to find how many substrings which is one-and-half palindromic.
输入格式
There is only one line containing a string (the length of string is less than or equal to 500 000). The string only consists of lowercase letters.
输出格式
Output an integer denoting the number of one-and-half palindromic substrings.
样例
input
abcbabc
output
1
input
abccbaabc
output
0
input
ababcbabccbaabc
output
2
复盘: 题意多看两遍也看明白了,寻找两个回文半径能互相覆盖的回文中心。然后我写下了初版代码,然后超时了…于是通过搜索,看到了fnq9999的代码,首先是三个关键词:马拉车(manacher
)、主席树(可持久化线段树)、树状数组。通过搜索学习了一下manacher
和树状数组。
PS: 有错误欢迎指正!
1、manacher
先说说manacher
。manacher
可以实现O(n)时间复杂度的计算每个字符的回文半径。实现过程大致分两步。
1、对字符串进行处理,举个例子,将abba
处理为$#a#b#b#a#
。
其中$
是为了防止下面manacher
函数中while(initStr[i-len[i]]==initStr[i+len[i]])len[i]++;
发生越界。
加#
是为了统一整个字符串长度的奇偶性,保证在计算某一字符串最长回文长度时不会出错。例如abba
是回文字符串,但是没有某一个字符是回文中心。对于本题来说,我觉得似乎没有这个问题,因为我们必须要找到每个字符的回文半径。
string init(string s){
string tmp;
tmp.resize(s.size()*2+2);
tmp[0]='$';
for(int i=0;i<s.size();i++){
tmp[i*2+1]='#';
tmp[i*2+2]=s[i];
}
tmp[tmp.size()-1]='#';
return tmp;
}
2、处理字符串的回文半径
initStr | … | c | # | b | # | d | # | b | # | a | … |
---|---|---|---|---|---|---|---|---|---|---|---|
索引 | … | 2 | 3 | 4 | 5 | id=6 | 7 | 8 | 9 | mx=10 | … |
vector<int> manacher(string s){
string initStr=init(s);//对字符串预处理,类似于将abba处理为$#a#b#b#a#
int n=s.size(),m=initStr.size();
vector<int> ans(n),len(m);//ans是最后用来保存s的每个字符的回文半径,len保存处理后的字符initStr的每个字符的回文半径
int id,mx=0;//mx记录最新的回文串右边界,id保存mx回文串右边界对应的回文串中心,类似于上表格
for(int i=1;i<m;i++){
if(i<mx)len[i]=min(mx-i,len[id-(i-id)]);//核心。len[id*2-i]=len[id-(i-id)]就是i关于id对称的点的回文半径
//由于mx及mx右边的点我们还没有检查过,所以根据已知信息能取的半径不能超过mx-i
else len[i]=1;//当i>=mx没有任何已知信息,初始化为1
while(initStr[i-len[i]]==initStr[i+len[i]])len[i]++;//进一步检查当前i位置字符的回文范围
if(mx<i+len[i]){//当i位置的回文右边界超出之前的边界就更新id和mx
id=i;
mx=i+len[i];
}
}
for(int i=1;i<m;i++)if(initStr[i]!='#')ans[i/2-1]=len[i]/2;//获取我们需要的关于s每个字符的回文半径
return ans;
}
一些详细细节可以看一下沉~杉的这篇文章。
2、树状数组
树状数组实现的是O(logn)
时间复杂度的更新前缀和(add函数)与查询前缀和(sum函数)。
树状数组形态大致如下,*为虚拟节点,实际并不存在。每个数的二进制末尾0的个数越多,离根节点越近。
(1000)8 | |||||||
---|---|---|---|---|---|---|---|
(100)4 | * | ||||||
(10)2 | * | (110)6 | * | ||||
(1)1 | * | (11)3 | * | (101)5 | * | (111)7 | * |
索引 | (1)1 | (10)2 | (11)3 | (100)4 | (101)5 | (110)6 | (111)7 | (1000)8 |
---|---|---|---|---|---|---|---|---|
原数组 | a | b | c | d | e | f | g | h |
树状数组c | a | a+b | c | a+b+c+d | e | a+b+c+d+e+f | g | a+b+c+d+e+f+g+h |
//求取二进制最后一位1的大小,举个例子,x=110(6),返回值为10(2);x=101(5),返回值为1(1)
//x!=0时,-x和x互为相反数,不如设x为正数,-x机器数用补码表示是对x取反,再加一,&操作之后,高位全为零,
//仅留最后一位1和相同数量的末尾0。就以x=6(四位,最高位为符号位)为例x=0110,-x原码取反:1001,加一即反码:1010。
//高位由于取反&的时候全为0,x末尾有多少个连续的0,取反就有多少个连续的1,再加一得反码就又会得到相同数量的末尾0以及进位的一个1
int lowbit(int x){
return x&(-x);
}
long sum(vector<long> c,int x){//c即树状数组。查询x位的前缀和
long ans=0;
while(x>0){
ans+=c[x];
x-=lowbit(x);//找到前一个分支
}
return ans;
}
void add(vector<long> &c,int x,int d){//逻辑上对索引为x的数加上d。更新被影响到的所有前缀和
while(x<c.size()){
c[x]+=d;
x+=lowbit(x);//找到父节点,更新父节点
}
}
最后来看一下这题fnq9999的代码(注释了部分代码,因为输入的方式问题)如何使用树状数组吧。一开始我也没看懂后面的操作,找个例子debug多看两遍会方便理解一点
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
using namespace std;
const int maxn = 1e6 + 10;
typedef long long ll;
int T;
char t[maxn], s[maxn];
int n, p[maxn];
void init() {//manacher求回文串前对字符串预处理
int len = strlen(t);
s[0] = '$'; s[1] = '#';
n = 2;
for (int i = 0; i < len; i++) {
s[n++] = t[i]; s[n++] = '#';
}
s[n] = 0;
}
void manacher() {//求回文半径,存入p数组
int id, mx = 0;
for (int i = 1; i < n; i++) {
if (i < mx) p[i] = min(p[2*id-i], mx-i);
else p[i] = 1;
while (s[i-p[i]] == s[i+p[i]]) p[i]++;
if (mx < i + p[i]) { id = i; mx = i + p[i]; }
}
int m = n;
n = 0;
for (int i = 1; i < m; i++) if (s[i] != '#') p[++n] = p[i]/2 ;
}
ll C[maxn];//树状数组
//lowbit,sum(查询前缀和),add(在对位置x的数加d后,更新被影响的前缀和)树状数组三件套
int lowbit(int x) { return x & (-x); }
ll sum(int x) {
ll ans = 0;
while (x > 0) { ans += C[x]; x -= lowbit(x); }
return ans;
}
void add(int x, int d) {
while (x <= n) { C[x] += d; x += lowbit(x); }
}
vector<int> v[maxn];
int main() {
// #ifdef swt
// freopen("input2.txt","r",stdin);
// #endif // swt
// #define en '\n'
// scanf("%d", &T);
// while (T--) {
scanf("%s", t);
init();
manacher();
for (int i = 1; i <= n; i++) v[i].clear();//初始化
for (int i = 1; i <= n; i++) v[i-p[i]+1].push_back(i);//从当前位置向前看,在最左边的位置存放当前回文中心索引
for (int i = 1; i <= n; i++) C[i] = 0;//初始化树状数组
ll ans = 0;
for (int i = 1; i <= n; i++) {
for (int j = 0; j < v[i].size(); j++) add(v[i][j] ,1);//对每个新出现的回文中心索引对应的树状数组的原数组值+1
ans += (sum(i + p[i]-1) - sum(i));//获取当前位置之后与当前回文右边界之前有多少个回文中心
}
printf("%lld\n", ans);
// }
return 0;
}
初版:
#include<iostream>
#include<vector>
#include<cmath>
using namespace std;
int main(){
string s;
getline(cin,s);
long ans=0;
int n=s.size();
vector<int> maxpa(n,0),maxN;
for(int i=0;i<n;i++){
int left=i-1,right=i+1,tmp=0;
while(left>=0&&right<n&&s[left--]==s[right++])tmp++;
maxpa[i]=tmp;
}
for(int i=0;i<n;i++){
for(int j=i+1;j<n;j++){
if(maxpa[i]>=j-i&&maxpa[j]>=j-i)ans++;
}
}
cout<<ans<<endl;
}
下面是我使用manacher
与树状数组的思想改的代码,但是依旧超时,应该在一些变量声明上存在问题导致超时。欢迎大家指出错误!
#include<iostream>
#include<vector>
#include<cmath>
using namespace std;
string init(string s){
string tmp;
tmp.resize(s.size()*2+2);
tmp[0]='$';
for(int i=0;i<s.size();i++){
tmp[i*2+1]='#';
tmp[i*2+2]=s[i];
}
tmp[tmp.size()-1]='#';
return tmp;
}
vector<int> manacher(string s){
string initStr=init(s);
int n=s.size(),m=initStr.size();
vector<int> ans(n),len(m);
int id,mx=0;
for(int i=1;i<m;i++){
if(i<mx)len[i]=min(mx-i,len[id-(i-id)]);
else len[i]=1;
while(initStr[i-len[i]]==initStr[i+len[i]])len[i]++;
if(mx<i+len[i]){
id=i;
mx=i+len[i];
}
}
for(int i=1;i<m;i++)if(initStr[i]!='#')ans[i/2-1]=len[i]/2;
return ans;
}
int lowbit(int x){
return x&(-x);
}
long sum(vector<long> c,int x){
long ans=0;
while(x>0){
ans+=c[x];
x-=lowbit(x);
}
return ans;
}
void add(vector<long> &c,int x,int d){
while(x<c.size()){
c[x]+=d;
x+=lowbit(x);
}
}
int main(){
string s;
getline(cin,s);
long ans=0;
int n=s.size();
vector<int> maxpa=manacher(s);
vector<vector<int>> v(n+1,vector<int>());
vector<long> c(n+1);
for(int i=1;i<=n;i++)v[i-maxpa[i-1]+1].push_back(i);
for(int i=1;i<=n;i++){
for(int j=0;j<v[i].size();j++){
add(c,v[i][j],1);
}
ans+=(sum(c,i+maxpa[i-1]-1)-sum(c,i));
}
cout<<ans<<endl;
}