我:!@#¥%……&*()
原来这就是一个套路(* ̄︶ ̄)
我们从暴力入手,\(f[i+1][T][j+|T|] += f[i][S][j]\)
其中\(S,T\)表示该行哲学家的状态,而\(|T|\)表示\(T\)状态放的哲学家个数。
我们可以将第三维去掉,改成多项式的形式:\(g[i][S]=∑_{j=0}^{3*n+1}f[i][S][j]*x^j\),因为这样便于快速求出第\(n\)项(矩乘)。
这样我们可以带入\(3*n+1\)个\(x\)然后与最终结果累加的和组成\(3*n+1\)个点值,最后用\(NTT\)得到系数表达式,
而此时的\(b[m]\)显然就是最终结果了。
如何带入\(x\)快速得到\(∑g[n][S]\)?用快速幂,对于\(S\)状态到\(T\)状态,可以转移的话,矩阵填上的数即为\(x^|T|\),好巧妙啊。。。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define g 3
#define mo 998244353
#define ll long long
#define mem(x, a) memset(x, a, sizeof x)
#define mpy(x, y) memcpy(x, y, sizeof y)
#define fo(x, a, b) for (int x = (a); x <= (b); x++)
#define fd(x, a, b) for (int x = (a); x >= (b); x--)
#define go(x) for (int p = tail[x], v; p; p = e[p].fr)
using namespace std;
struct matrix{int a[9][9], n, m;}aw, zy, c;
int n, m, a[4][4], dl[4], to[9][9], b[15010], r[15010];
inline int read() {
int x = 0, f = 0; char c = getchar();
while (c < '0' || c > '9') f = (c == '-') ? 1 : f, c = getchar();
while (c >= '0' && c <= '9') x = (x << 1) + (x << 3) + (c ^ 48), c = getchar();
return f ? -x : x;
}
bool judge(int x) {return ! (((x & 1) || (x & 4)) && (x & 2) && (a[2][3] || a[2][1]));}
bool check(int x, int y) {
if (! judge(x)) return 0;
if (! judge(y)) return 0;
if (x & 1) {if ((a[3][2] && (y & 1)) || (a[3][3] && (y & 2))) return 0;}
if (x & 2) {if ((a[3][1] && (y & 1)) || (a[3][2] && (y & 2)) || (a[3][3] && (y & 4))) return 0;}
if (x & 4) {if ((a[3][1] && (y & 2)) || (a[3][2] && (y & 4))) return 0;}
if (y & 1) {if ((a[1][2] && (x & 1)) || (a[1][3] && (x & 2))) return 0;}
if (y & 2) {if ((a[1][1] && (x & 1)) || (a[1][2] && (x & 2)) || (a[1][3] && (x & 4))) return 0;}
if (y & 4) {if ((a[1][1] && (x & 2)) || (a[1][2] && (x & 4))) return 0;}
return 1;
}
ll ksm(ll x, int y) {
ll s = 1;
while (y) {
if (y & 1) s = s * x % mo;
x = x * x % mo, y >>= 1;
}
return s;
}
matrix ksc(matrix a, matrix b) {
mem(c.a, 0); c.n = a.n, c.m = b.m;
fo(k, 1, a.m) fo(i, 1, a.n) fo(j, 1, b.m)
c.a[i][j] = (c.a[i][j] + (ll)a.a[i][k] * b.a[k][j]) % mo;
return c;
}
ll solve(int x) {
mem(aw.a, 0);
fo(i, 1, 8) fo(j, 1, 8)
if (to[i - 1][j - 1] == -1) zy.a[i][j] = 0;
else zy.a[i][j] = ksm(x, to[i - 1][j - 1]);
aw.a[1][1] = 1; aw.n = 1, aw.m = 8, zy.n = zy.m = 8;
int cs = n;
while (cs) {
if (cs & 1) aw = ksc(aw, zy);
zy = ksc(zy, zy), cs >>= 1;
}
ll ans = 0;
fo(i, 1, 8) (ans += aw.a[1][i]) %= mo;
return ans;
}
void NTT(int *x, int n, int type) {
fo(i, 0, n - 1) if (i < r[i]) swap(x[i], x[r[i]]);
for (int i = 1; i < n; i <<= 1) {
ll wn = ksm(g, (type * (mo - 1) / (i << 1) + mo - 1) % (mo - 1));
for (int j = 0; j < n; j += (i << 1)) {
ll w = 1;
for (int k = 0; k < i; k++, w = w * wn % mo) {
int a = x[j + k], b = w * x[j + k + i] % mo;
x[j + k] = (a + b) % mo, x[j + k + i] = (a - b + mo) % mo;
}
}
}
}
int main()
{
freopen("final.in", "r", stdin);
freopen("final.out", "w", stdout);
n = read(), m = read();
fo(i, 1, 3) fo(j, 1, 3) a[i][j] = read();
fo(i, 0, 7) fo(j, 0, 7) {
if (! check(i, j)) to[i][j] = -1;
else to[i][j] = ((j & 1) == 1) + ((j & 2) == 2) + ((j & 4) == 4);
}
int len = 1, times = 0;
while (len <= 3 * n + 1) len <<= 1, times++;
ll wn = ksm(g, (mo - 1) / len), w = 1;
fo(i, 0, len - 1) b[i] = solve(w), w = w * wn % mo;
fo(i, 0, len - 1) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (times - 1));
NTT(b, len, -1);
ll inv = ksm(len, mo - 2);
printf("%lld\n", b[m] * inv % mo);
return 0;
}