题目
思路
我的方法
我的思路比较简单,就是直接计算出每个前缀中,用这种方法得到的不同字符串数量。
那么前缀 i i i 可生成的字符串集合 S i S_i Si 有两类:一种是使用了第 i i i 个字符,形如 S i − 2 + c S_{i-2}+c Si−2+c,另一种不使用第 i i i 个字符,即 S i − 1 S_{i-1} Si−1 。
显然同一类内,字符串不相同,所以只需要算二者的交集,即:有多少个 S i − 1 S_{i-1} Si−1 内的字符串,是以字符 c c c 结尾,且去掉字符 c c c 之后属于 S i − 2 S_{i-2} Si−2 的?
仔细一想,如果第 i − 1 i-1 i−1 个字符不是 c c c,那么这就是铁定的,因为以 c c c 结尾的字符串整个都属于 S i − 2 S_{i-2} Si−2 了;哪怕第 i − 1 i-1 i−1 个字符是 c c c,并且 S i − 1 S_{i-1} Si−1 中的字符串使用了它,将其去掉也就变为 S i − 3 ⫅ S i − 2 S_{i-3}\subseteqq S_{i-2} Si−3⫅Si−2 了。
这就是说:对于以第 i i i 个字符结尾的 l e n > 1 len>1 len>1 的字符串,它必然是由前缀 i − 2 i-2 i−2 产生的任意字符串拼接上该字符;其余字符结尾,都可以简单地从 S i − 1 S_{i-1} Si−1 直接拷贝而来。
那么便用
f
(
i
,
c
)
f(i,c)
f(i,c) 表示前缀
i
i
i 生成的以字符
c
c
c 结尾的字符串数量。若当前字符为
c
0
c_0
c0,则
f
(
i
,
c
0
)
=
1
+
∑
c
∈
σ
f
(
i
−
2
,
c
)
f(i,c_0)=1+\sum_{c\in\sigma}f(i-2,c)
f(i,c0)=1+c∈σ∑f(i−2,c)
其中 σ \sigma σ 表示字符集。其余的则是 f ( i , c ) = f ( i − 1 , c ) ( c ≠ c 0 ) f(i,c)=f(i-1,c)\;(c\ne c_0) f(i,c)=f(i−1,c)(c=c0) 。这就是一个简单的 O ( ∣ σ ∣ n ) \mathcal O(|\sigma|n) O(∣σ∣n) 的 d p \tt dp dp 了。
看看题解,可以是 O ( n ) \mathcal O(n) O(n) 的!好,开始优化。发现每次只会修改一个位置,并且是利用 i − 2 i-2 i−2 的总和。只需存储 l a s t = ∑ c ∈ σ f ( i − 1 , c ) last=\sum_{c\in\sigma}f(i-1,c) last=∑c∈σf(i−1,c) 和 n o w = ∑ c ∈ σ f ( i , c ) now=\sum_{c\in\sigma}f(i,c) now=∑c∈σf(i,c),同时把 f ( i , c ) f(i,c) f(i,c) 这个长度为 O ( ∣ σ ∣ ) \mathcal O(|\sigma|) O(∣σ∣) 的数组存下来。这样就是 O ( 1 ) \mathcal O(1) O(1) 转移了!
官方题解
这个做法其实本质上就来源于:序列自动机。在 O I w i k i \rm OI\;wiki OIwiki 上了解它一下就行了。
于是可以考虑,有多少个子序列是刚好匹配到第
i
i
i 个状态。显然最后一个字符是字符串的第
i
i
i 个字符,并且需要不能走到更靠前的一个,记
k
(
k
<
i
)
k\;(k<i)
k(k<i) 为最靠近
i
i
i 的相同字符的位置,则
g
(
i
)
=
∑
j
=
k
−
1
i
−
2
g
(
j
)
g(i)=\sum_{j=k-1}^{i-2}g(j)
g(i)=j=k−1∑i−2g(j)
然后答案就是 ∑ i = 1 n g ( i ) \sum_{i=1}^{n}g(i) ∑i=1ng(i) 了。前缀和优化一下就是 O ( n ) \mathcal O(n) O(n) 的。
初值为 g ( 0 ) = 1 g(0)=1 g(0)=1 么?不完全是。因为我们要选择的字符不能相邻。 0 0 0 相当于一个虚点,这样会忽略掉第一个字符,所以正确的赋值是 g ( − 1 ) = 1 g(-1)=1 g(−1)=1 。实现的时候可以先赋值 g ( 1 ) = 1 g(1)=1 g(1)=1,然后 k = 0 k=0 k=0 再额外补上这玩意儿。
代码
我的实现
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long int_;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int MaxN = 200005;
const int Mod = 1e9+7;
int dp[MaxN]; char s[MaxN];
int main(){
scanf("%s",s+1);
int n = strlen(s+1);
if(n == 1) return puts("1")*0;
dp[s[1]-'a'] = 1;
int lst = 1, now = 1;
if(s[1] != s[2]){
dp[s[2]-'a'] = 1;
now = 2; // how many
}
rep(i,3,n){
int nv = now-dp[s[i]-'a'];
nv = (nv+lst+1)%Mod;
if(nv < 0) nv += Mod;
dp[s[i]-'a'] = lst+1;
lst = now, now = nv;
}
printf("%d\n",now);
return 0;
}
官方题解
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long int_;
# define rep(i,a,b) for(int i=(a); i<=(b); ++i)
# define drep(i,a,b) for(int i=(a); i>=(b); --i)
inline int readint(){
int a = 0; char c = getchar(), f = 1;
for(; c<'0'||c>'9'; c=getchar())
if(c == '-') f = -f;
for(; '0'<=c&&c<='9'; c=getchar())
a = (a<<3)+(a<<1)+(c^48);
return a*f;
}
const int Mod = 1e9+7;
const int MaxN = 200005;
char s[MaxN];
int lst[MaxN], dp[MaxN];
int main(){
scanf("%s",s+1);
int n = strlen(s+1);
dp[1] = lst[s[1]] = 1;
rep(i,2,n){
int &p = lst[s[i]];
dp[i] = dp[i-2];
if(p >= 2)
dp[i] -= dp[p-2];
else if(!p) ++ dp[i];
dp[p = i] = (dp[i]+Mod)%Mod;
// get prefix sum
dp[i] = (dp[i]+dp[i-1])%Mod;
}
printf("%d\n",dp[n]);
return 0;
}