一、构造矩阵
如果只需要求
f
f
,那么就是裸的矩阵乘法问题。
建立的矩阵
A
A
,第一行全部是,并且对于所有的
1<i≤m
1
<
i
≤
m
,矩阵的第
i
i
行第列是
1
1
,剩下的个位置全部是
0
0
。
那么就等于
AS
A
S
的第
1
1
行第列。
二、问题转化
回到问题,就样例进行分析。
发现
f(1+2+3)+f(1+23)+f(12+3)+f(123)
f
(
1
+
2
+
3
)
+
f
(
1
+
23
)
+
f
(
12
+
3
)
+
f
(
123
)
可以看成是
A1+2+3+A1+23+A12+3+A123
A
1
+
2
+
3
+
A
1
+
23
+
A
12
+
3
+
A
123
的第
1
1
行第列。
又由于矩阵乘法满足结合律,所以进一步转化:
发现是 A A 的次幂之和,并且所有的指数都是原数字串的子串。
三、预处理
发现指数非常大,求的幂啃腚会T掉。
考虑预处理出
P[i][j]
P
[
i
]
[
j
]
,即:
可以递推来求,省去快速幂的 log log :
那么对于一个 k k 位的数字串,从高位到低位分别是 a1,a2,...,ak a 1 , a 2 , . . . , a k ,那么 AS A S 就是:
4、DP
已经处理了主要问题,就可以DP了!
设
dp[i]
d
p
[
i
]
表示数字串前
i
i
个数字拆分后加起来的矩阵。
就样例说,就是
A1+2+A12
A
1
+
2
+
A
12
,
dp[3]
d
p
[
3
]
就是
A1+2+3+A1+23+A12+3+A123
A
1
+
2
+
3
+
A
1
+
23
+
A
12
+
3
+
A
123
。
最后答案就是
dp[n]
d
p
[
n
]
(
n
n
为数字串长度)的第行第
1
1
列。
边界条件:为
m×m
m
×
m
的单位矩阵。
转移(
Si
S
i
为数字串第
i
i
位的值,表示数字串的子串
[l,r]
[
l
,
r
]
的值):
Anum[j+1,i] A n u m [ j + 1 , i ] 可以用预处理出的 P P 高效求得,可以将倒着循环,那么就有:
复杂度 O(n2×m3) O ( n 2 × m 3 ) 。
五、代码
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int N = 505, M = 6, V = 12, MX = 998244353;
struct cyx {
int n, m, a[M][M]; cyx() {}
cyx(int _n, int _m) : n(_n), m(_m) {memset(a, 0, sizeof(a));}
friend inline cyx operator + (cyx a, cyx b) {
int i, j; cyx res = cyx(a.n, b.m);
for (i = 1; i <= a.n; i++) for (j = 1; j <= b.m; j++)
res.a[i][j] = (a.a[i][j] + b.a[i][j]) % MX; return res;
}
friend inline cyx operator * (cyx a, cyx b) {
int i, j, k; cyx res = cyx(a.n, b.m);
for (i = 1; i <= a.n; i++) for (j = 1; j <= b.m; j++)
for (k = 1; k <= a.m; k++) res.a[i][j] = (res.a[i][j] +
1ll * a.a[i][k] * b.a[k][j] % MX) % MX; return res;
}
friend inline cyx operator ^ (cyx a, int b) {
int i; cyx res = cyx(a.n, a.m); for (i = 1; i <= a.n; i++)
res.a[i][i] = 1; while (b) {
if (b & 1) res = res * a;
a = a * a; b >>= 1;
}
return res;
}
} tr, g[N], p10[V][N];
cyx _(int n) {
int i; cyx res = cyx(n, n); for (i = 1; i <= n; i++)
res.a[i][i] = 1; return res;
}
void write(cyx a) {
int i, j; for (i = 1; i <= a.n; i++) {
for (j = 1; j <= a.m; j++) cout << a.a[i][j] << " ";
cout << endl;
}
}
char s[N]; int n, m;
int main() {
scanf("%s", s + 1); n = strlen(s + 1); cin >> m; tr = cyx(m, m);
int i, j; for (i = 1; i <= m; i++) tr.a[1][i] = 1;
for (i = 2; i <= m; i++) tr.a[i][i - 1] = 1; for (i = 0; i < 10; i++) {
p10[i][0] = tr ^ i; for (j = 1; j <= n; j++)
p10[i][j] = p10[i][j - 1] ^ 10;
}
g[0] = _(m); for (i = 1; i <= n; i++) {
g[i] = cyx(m, m); cyx trs = _(m); for (j = i; j; j--) {
trs = trs * p10[s[j] - '0'][i - j];
g[i] = g[i] + g[j - 1] * trs;
}
}
cout << g[n].a[1][1] << endl;
return 0;
}