通过观察可以发现两位数的组合是有限的,所以两位数可以如同一位数一样写出dp方程:
dp[i + 1][6] = dp[i][6] + dp[i][4];
dp[i + 1][4] = dp[i][2] + dp[i][6];
dp[i + 1][2] = dp[i][1];
dp[i + 1][1] = dp[i][4];
dp[i + 1][61] = dp[i][44];
dp[i + 1][62] = dp[i][41];
dp[i + 1][64] = dp[i][6] + dp[i][42];
dp[i + 1][66] = dp[i][46];
dp[i + 1][41] = dp[i][64];
dp[i + 1][42] = dp[i][61];
dp[i + 1][44] = dp[i][62];
dp[i + 1][46] = dp[i][26] + dp[i][66];
dp[i + 1][26] = dp[i][16];
dp[i + 1][16] = dp[i][4];
为了节省内存,我们需要对这14个数进行编号:
id | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
val | 1 | 2 | 62 | 4 | 61 | 6 | 64 | 66 | 41 | 42 | 44 | 46 | 26 | 16 |
替换之后的dp方程:
dp[i + 1][6] = dp[i][6] + dp[i][4];
dp[i + 1][4] = dp[i][2] + dp[i][6];
dp[i + 1][2] = dp[i][1];
dp[i + 1][1] = dp[i][4];
dp[i + 1][5] = dp[i][11];
dp[i + 1][3] = dp[i][9];
dp[i + 1][7] = dp[i][6] + dp[i][10];
dp[i + 1][8] = dp[i][12];
dp[i + 1][9] = dp[i][7];
dp[i + 1][10] = dp[i][5];
dp[i + 1][11] = dp[i][3];
dp[i + 1][12] = dp[i][13] + dp[i][8];
dp[i + 1][13] = dp[i][14];
dp[i + 1][14] = dp[i][4];
然后就有了52分的版本:
#include <iostream>
#include <stdio.h>
#include <map>
#define ll long long
using namespace std;
ll dp[500005][16];
const int mod = 998244353;
map<string, int>mp;
int main() {
mp["1"] = 1; mp["2"] = 2; mp["4"] = 4; mp["6"] = 6;
mp["16"] = 14; mp["26"] = 13; mp["41"] = 9; mp["42"] = 10; mp["44"] = 11;
mp["46"] = 12; mp["61"] = 5; mp["62"] = 3; mp["64"] = 7; mp["66"] = 8;
int n;
cin >> n;
dp[1][2] = 1;
for (int i = 1; i < n; i++) {
dp[i + 1][6] = dp[i][6] + dp[i][4];
dp[i + 1][4] = dp[i][2] + dp[i][6];
dp[i + 1][2] = dp[i][1];
dp[i + 1][1] = dp[i][4];
dp[i + 1][5] = dp[i][11];
dp[i + 1][3] = dp[i][9];
dp[i + 1][7] = dp[i][6] + dp[i][10];
dp[i + 1][8] = dp[i][12];
dp[i + 1][9] = dp[i][7];
dp[i + 1][10] = dp[i][5];
dp[i + 1][11] = dp[i][3];
dp[i + 1][12] = dp[i][13] + dp[i][8];
dp[i + 1][13] = dp[i][14];
dp[i + 1][14] = dp[i][4];
for (int j = 1; j <= 14; j++)
dp[i + 1][j] %= mod;
}
string s;
cin >> s;
int idx = mp[s];
cout << dp[n][idx] << endl;
}
因为数据量到了1e9,所以需要用到矩阵来进行优化:
(矩阵行为dp[i+1],列为dp[i],可用矩阵乘法自己算一下)
矩阵一定要写对,不然debug时会疯掉!!!
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4
1 0,1,0,0,0,0,0,0,0,0,0,0,0,0,
2 0,0,0,1,0,0,0,0,0,0,0,0,0,0,
3 0,0,0,0,0,0,0,0,0,0,1,0,0,0,
4 1,0,0,0,0,1,0,0,0,0,0,0,0,1,
5 0,0,0,0,0,0,0,0,0,1,0,0,0,0,
6 0,0,0,1,0,1,1,0,0,0,0,0,0,0,
7 0,0,0,0,0,0,0,0,1,0,0,0,0,0,
8 0,0,0,0,0,0,0,0,0,0,0,1,0,0,
9 0,0,1,0,0,0,0,0,0,0,0,0,0,0,
0 0,0,0,0,0,0,1,0,0,0,0,0,0,0,
1 0,0,0,0,1,0,0,0,0,0,0,0,0,0,
2 0,0,0,0,0,0,0,1,0,0,0,0,0,0,
3 0,0,0,0,0,0,0,0,0,0,0,1,0,0,
4 0,0,0,0,0,0,0,0,0,0,0,0,1,0
之后我们就得到了96分的代码:
#include <cstring>
#include <string.h>
#include <iostream>
#include <stdio.h>
#include <unordered_map>
#define LL long long
using namespace std;
const int mod = 998244353;
unordered_map<string, int>mp;
void mul(LL c[][14], LL a[][14], LL b[][14]) {
static LL tmp[14][14];
memset(tmp, 0, sizeof tmp);
for (int i = 0; i < 14; ++i)
for (int j = 0; j < 14; ++j)
for (int k = 0; k < 14; ++k)
tmp[i][j] = (tmp[i][j] + a[i][k] * b[k][j]) % mod;
memcpy(c, tmp, sizeof tmp);
}
int qmi(string s, int n) {
if (n < 0) return 0;
LL f[14][14] = { 1 };
LL a[14][14] =
{
0,1,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,1,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,1,0,0,0,
1,0,0,0,0,1,0,0,0,0,0,0,0,1,
0,0,0,0,0,0,0,0,0,1,0,0,0,0,
0,0,0,1,0,1,1,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,1,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,1,0,0,
0,0,1,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,1,0,0,0,0,0,0,0,
0,0,0,0,1,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,1,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,1,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,1,0
};
while (n) {
if (n & 1) mul(f, f, a);
mul(a, a, a);
n >>= 1;
}
return (int)f[0][mp[s]-1];
}
int main() {
mp["1"] = 1; mp["2"] = 2; mp["4"] = 4; mp["6"] = 6;
mp["16"] = 14; mp["26"] = 13; mp["41"] = 9; mp["42"] = 10; mp["44"] = 11;
mp["46"] = 12; mp["61"] = 5; mp["62"] = 3; mp["64"] = 7; mp["66"] = 8;
int n, ans;
string s;
cin >> n >> s;
if (s.size() == 1 || s.size() == 2) ans = qmi(s, n) % mod;
cout << ans;
return 0;
}
由于最后一个测试数据s的取值范围在1e5 ,所以我们进行回溯
任何一段数字最后都能回溯成两位数或一位数
例如 : 6264(n) : 因为初始字段(2,4,16,64)没有以62开头的,所以我们要在前面添上1,变成16264,而16264向上一层回溯的结果为416(n-1),在向上一层回溯结果为24(n-2),变成两位数即可利用上边的思路进行求解.
而464则会有两种回溯策略,一为直接回溯为26(n-1),二为在首位添加6,变成6464,回溯为66(n-1),而最终结果就为<n-1,26>&<n-1,66>两种查找结果之和,所以利用bfs的思路进行求解.
最后,满分代码(本蒟蒻哭了,):
#include <cstring>
#include <string.h>
#include <iostream>
#include <stdio.h>
#include <unordered_map>
#define LL long long
using namespace std;
const int mod = 998244353;
unordered_map<string, int>mp;
#define TLE ios::sync_with_stdio(0),cin.tie(0)
const int INF = 0x3f3f3f3f;
pair<string, int> q[1000005]; int ans = 0;
void mul(LL c[][14], LL a[][14], LL b[][14]) { //矩阵乘法
static LL tmp[14][14];
memset(tmp, 0, sizeof tmp);
for (int i = 0; i < 14; ++i)
for (int j = 0; j < 14; ++j)
for (int k = 0; k < 14; ++k)
tmp[i][j] = (tmp[i][j] + a[i][k] * b[k][j]) % mod;
memcpy(c, tmp, sizeof tmp);
}
int qmi(string s, int n) {
if (n < 0) return 0;
LL f[14][14] = { 1 };
LL a[14][14] =
{
0,1,0,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,1,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,1,0,0,0,
1,0,0,0,0,1,0,0,0,0,0,0,0,1,
0,0,0,0,0,0,0,0,0,1,0,0,0,0,
0,0,0,1,0,1,1,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,1,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,1,0,0,
0,0,1,0,0,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,1,0,0,0,0,0,0,0,
0,0,0,0,1,0,0,0,0,0,0,0,0,0,
0,0,0,0,0,0,0,1,0,0,0,0,0,0,
0,0,0,0,0,0,0,0,0,0,0,1,0,0,
0,0,0,0,0,0,0,0,0,0,0,0,1,0
};
while (n) { //快速幂
if (n & 1) mul(f, f, a);
mul(a, a, a);
n >>= 1;
}
return (int)f[0][mp[s] - 1];
}
string backtrack(string ch) { //回溯
string upper = "";
for (int i = 0; i < ch.size(); ++i) {
if (ch[i] == '2') upper += '1';
else if (ch[i] == '1' && (i == ch.size() - 1 || ch[i + 1] == '6')) upper += '4', i++;
else if (ch[i] == '6' && (i == ch.size() - 1 || ch[i + 1] == '4')) upper += '6', i++;
else if (ch[i] == '4') upper += '2';
else return "";
}
return upper;
}
void bfs(string start, int n) {
int step = 0, t = 0;
q[t] = { start, n };
while (step <= t) {
pair<string, int> u = q[step++];
string s = u.first;
int depth = u.second;
if (s == "" || depth < 0) continue;
if (s.size() == 1 || s.size() == 2) ans = (ans + qmi(s, depth)) % mod;
else {
q[++t] = { backtrack(s),depth - 1 }; // 开头: 1 -> 4 ,6 -> 6, 4 -> 2, 2 -> 1
if (s[0] == '4') q[++t] = { backtrack("6" + s), depth - 1 }; // 开头: 4 -> 6
else if (s[0] == '6') q[++t] = { backtrack("1" + s), depth - 1 }; //开头: 6 -> 4
}
}
}
int main() {
TLE;
mp["1"] = 1; mp["2"] = 2; mp["4"] = 4; mp["6"] = 6;
mp["16"] = 14; mp["26"] = 13; mp["41"] = 9; mp["42"] = 10; mp["44"] = 11;
mp["46"] = 12; mp["61"] = 5; mp["62"] = 3; mp["64"] = 7; mp["66"] = 8;
int n;
string s;
cin >> n >> s;
bfs(s, n);
cout << ans << endl;
return 0;
}