题意:
给一个⻓度为
n
n
n的三进制串,有这样⼀个操作:在每个
2
2
2后⾯插⼊⼀个
1
1
1 ,每个
1
1
1后⾯插⼊⼀个
0
0
0,然后删掉第⼀个字符。问经过多少次操作后,该串变为空串。
思路:
考虑遍历串的每一位,根据当前已经有过的操作次数(记为
x
x
x)和当前位的情况来求解删除当前位及其产生的后续影响需要的操作次数。
假如当前位为
0
0
0,则其始终为一个
0
0
0,故操作次数
+
1
+1
+1
若当前位为
1
1
1,则前
x
x
x次操作及删除这个
1
1
1时,已经产生了(
x
+
1
x+1
x+1)个
0
0
0,故操作次数 + (
x
+
2
x + 2
x+2)
若当前位为
2
2
2,此时情况比较复杂,通过打表推公式可得出:操作次数 + (
6
∗
2
x
−
3
−
x
6*2^x - 3 - x
6∗2x−3−x)
因为答案会不断迭代,当存在多个 2 2 2时,操作次数会进行指数级别的迭代,故此时需要取模,而著名的降幂定理—拓展欧拉定理便可以用来解决该问题:
拓展欧拉定理:
a
x
m
o
d
p
=
a
x
m
o
d
φ
(
p
)
+
(
x
>
φ
(
p
)
?
φ
(
p
)
:
0
)
m
o
d
p
a^x \ mod \ p = a^{x \ mod \ \varphi(p) + (x > \varphi(p)?\varphi(p):0)} \ mod \ p
ax mod p=ax mod φ(p)+(x>φ(p)?φ(p):0) mod p
故我们可以提前设定一个变量 r e s res res,保存当前答案不取模的值,因为只要 r e s res res增长到大于 φ ( 1 e 9 + 7 ) \varphi(1e9 + 7) φ(1e9+7)之后,无论多大将对表达式产生同样的影响,故可以设定其上限为 1 e 9 + 7 1e9 + 7 1e9+7
但使用拓展欧拉定理降幂时,我们需要计算
x
m
o
d
φ
(
p
)
x \ mod \ \varphi(p)
x mod φ(p),而这个计算又涉及了:
2
x
m
o
d
φ
(
p
)
2^x \ mod \ \varphi(p)
2x mod φ(p)
故需要继续使用拓展欧拉定理,继续向下求取
x
m
o
d
φ
(
φ
(
p
)
)
x \ mod \ \varphi(\varphi(p))
x mod φ(φ(p))…
因为对一个数不停地求取欧拉函数,在 l o g ( n ) log(n) log(n)次以后将得到 1 1 1,而任何数模 1 1 1都是 0 0 0,故不再需要继续计算。
故总时间复杂度
O
(
n
l
o
g
2
(
n
)
)
O(n log^2(n))
O(nlog2(n))
此题得解。
代码:
#include<iostream>
#include<algorithm>
#include<cmath>
#include<string>
#include<cstring>
#include<cstdio>
#include<vector>
using namespace std;
typedef long long ll;
const int A = 2e5 + 10;
int tot = 29;
int M[50] = {0,1,2,4,8,16,32,64,128,256,512,1024,2048,4096,8192,16384,32768,65536,131072,262144,524288,1048576,2097152,5242880,19660800,79872000,243900800,500000002,1000000006,1000000007};
int cnt[50];
char s[A];
int Res;
ll fast_mod(ll n, ll m, int mod){
if ((mod & (mod - 1)) == 0 && m > 30) return 0;
ll res = 1;
while (m > 0) {
if (m & 1) res = res * n % mod;
n = n * n % mod;
m >>= 1;
}
return res;
}
int Fun(int num, int id) {
int mm = cnt[id-1] + (Res > M[id-1]?M[id - 1]:0);
int res = 1LL * 6 * fast_mod(2, mm, M[id]) % M[id];
res = (res - num - 3) % M[id];
if (res < 0) res += M[id];
return res;
}
int main(){
int T;
scanf("%d", &T);
while(T--){
scanf("%s",s);
int len = strlen(s);
memset(cnt, 0, sizeof(cnt));
Res = 0;
for (int i = 0; i < len ; i++) {
if (s[i] == '0') {
for (int j = 1; j <= tot; j++) {
cnt[j]++;
if (cnt[j] >= M[j]) cnt[j] -= M[j];
}
if (Res < M[tot]) Res++;
} else if (s[i] == '1') {
for (int j = 1; j <= tot; j++) {
cnt[j] = (2 * cnt[j] + 2) % M[j];
}
if (Res < M[tot]) Res = min(2 * Res + 2, M[tot]);
} else {
for (int j = tot; j >= 2; j--) {
cnt[j] = (cnt[j] + Fun(cnt[j],j)) % M[j];
}
if(Res >= 32) Res = M[tot];
else Res = min(1LL * M[tot], (6LL<<Res)- 3);
}
}
printf("%d\n", cnt[tot]);
}
return 0;
}