ISIJ 2018 奇怪的字符串(Training Round D6T1)
题目名称:奇怪的字符串
文件名称:strange.in / strange.out
题目描述
考虑字符串 s 仅由小写字母组成,例如 “abba”。定义 W(s) 为 s 所有本质不同的连续子串的集合,例如 W(“abba”) = { “a”,”b”,”ab”,”ba”,”bb”,”abb”,”bba”,”abba” }。定义 Y(s) 为 s 所有本质不同的非连续子串的集合,例如 Y(“abba”) = W(“abba”) ∪ { “aa”,”aba” },显然 W(s) 是 Y(s) 的子集。
对于一些奇怪的字符串 s 满足 W(s) = Y(s),例如 “abba” 就不是奇怪的,但 “abb” 是奇怪的因为 W(s) = Y(s) = { “a”,”b”,”ab”,”bb”,”abb” }。现在小明有一个字符串 s,请你求出 W(s) 中有多少个字符串是奇怪的?
注意:集合中的所有元素互不相同
限制
1s 256M
1≤|s|≤ 200,000
输入格式
一个字符串 s
输出格式
一个整数,表示 W(s) 中有多少个字符串是奇怪的
输入样例
abba
输出样例
7
样例解释
abba 的所有连续子串中,除了 abba 以外都是 ” 奇怪的 “
分析
呃,画个图体会一下:
如果我想要证明,区间[l,r]不是奇怪的,我们可以证明[l,r]有一个非连续的新子串。
如上图,你刚扫到了x点,x的左边是a,右边是b。如果在x左边,有一个非a(并不是全是a)的串,那么将它和x接起来,你就可以得到一个在连续的情况下死都得不到的新序列。右边同理。
将结论推广,我们可以得到,如果区间不是奇怪的,那么这个区间从左往右扫至少有三段字符串(规定每串字符串中字符都相同)。反推之,如果这个区间至多只有两段字符串,则区间一定是奇怪的。(必要且充分)
这时我们友好的发现,奇怪的子串只有两种,aa型或者ab型,注意字符串集合中元素不会重复出现。首先aa型非常好做,分别统计从’a’到’z’每种字符串的最长,在加起来就得到了aa型的所有子串。
for(long long i=1;i<=n;){
long long j=i;
while(j<n&&s[j]==s[j+1])j++;
length[s[i]-'a']=max(length[s[i]-'a'],j-i+1);
i=j+1;
}
for(int i=0;i<26;i++)
ans+=length[i];
ab型由于考虑到重复统计,且有两维,不能直接只取最大值。【例子(1,3)(3,1)(2,2)】我们考虑先开一个vector将所有二元组(局部最大)先存起来。
struct node{
long long l,r;
};
vector<node>a[26][26];//两维的vector
vector<node> ::iterator iv;//懒人象征(迭代器)
long long i=1,j,ml,mr,l,r;
ml=s[1],l=1;
while(s[i]==s[i+1])i++,l++;
while(i<n){
j=i+1;mr=s[j];r=1;
while(j<n&&s[j]==s[j+1]) j++,r++;
a[ml-'a'][mr-'a'].push_back((node){l,r});
i=j;l=r;ml=mr;
}
考虑二元组的计算,我们可以把它们看成一个个矩形,用扫描线求矩形面积并。
鉴于扫描线的代码量,我们开始思考优化方案,首先的我们想到,所有的横坐标都是从1到”长度“。可以用一个平衡树代替线段树,每次插入或删除一个值。查找出最大值,即为当前线段的长。平衡树可以用STL的set
上述优化可以减少代码量但不能优化时间,我们再介绍中更强的优化,简洁好写而且常数比前两种小很多。
bool cmp(node x,node y){
return x.l>y.l;
}
void calc(int x,int y){
sort(a[x][y].begin(),a[x][y].end(),cmp);
iv=a[x][y].begin();
long long maxx=(*iv).l,maxy=(*iv).r;
ans+=maxx*maxy,iv++;
for(;iv!=a[x][y].end();iv++){
if(maxy<(*iv).r){
ans+=(*iv).l*((*iv).r-maxy);
maxy=(*iv).r;
}
}
}
可否有些头绪?我们先按左边区间的长度排序,使得一侧长度单调递减,再判断右侧区间长度如果小于之前的最大长度。则一定被之前的某个区间包含过,不需要再计算。长度更长,则将超过最大长度的部分乘上自身的左侧区间,更新最大值。
完整代码如下:
#include<bits/stdc++.h>
using namespace std;
const int maxn=201000;
struct node{
long long l,r;
};
vector<node>a[26][26];
vector<node> ::iterator iv;
long long n;
long long length[26];
char s[maxn];
long long ans=0;
void init(){
long long i=1,j,ml,mr,l,r;
ml=s[1],l=1;
while(s[i]==s[i+1])i++,l++;
while(i<n){
j=i+1;mr=s[j];r=1;
while(j<n&&s[j]==s[j+1]) j++,r++;
a[ml-'a'][mr-'a'].push_back((node){l,r});
i=j;l=r;ml=mr;
}
return ;
}
bool cmp(node x,node y){
return x.l>y.l;
}
void calc(int x,int y){
sort(a[x][y].begin(),a[x][y].end(),cmp);
iv=a[x][y].begin();
long long maxx=(*iv).l,maxy=(*iv).r;
ans+=maxx*maxy,iv++;
for(;iv!=a[x][y].end();iv++){
if(maxy<(*iv).r){
ans+=(*iv).l*((*iv).r-maxy);
maxy=(*iv).r;
}
}
}
int main(){
memset(length,0,sizeof(length));
scanf("%s",s+1);
n=strlen(s+1);
for(long long i=1;i<=n;){
long long j=i;
while(j<n&&s[j]==s[j+1])j++;
length[s[i]-'a']=max(length[s[i]-'a'],j-i+1);
i=j+1;
}
for(int i=0;i<26;i++)
ans+=length[i];
init();
for(int i=0;i<26;i++)
for(int j=0;j<26;j++)
if(a[i][j].size()>0)
calc(i,j);
printf("%lld\n",ans);
return 0;
}