若环的长度为奇数,也就是没有正对着的点,那么我们可以用类似dp的方法F[i][j]表示第i位颜色为j的方案数
很容易可以推出公式 ans = (m-1) ^ n - (m-1)
至于环的长度为偶数,要考虑正对着点
考虑将环从长度一半处切下来,和剩下一半拼成一条宽度为2的环
每个点不能和它相邻的点相同
用F[i][j][k]表示第i位,上面颜色为j,下面颜色为k的方案数
颜色可以简化成为与第一列上面的数相同,与第一列下面的数相同,或者与这两个都不相同
那么颜色状态简化成3,每个位置的状态为3*3种
考虑矩阵快速幂,9种状态转移到另外9种状态,需要一个9*9的矩阵
分类讨论一下求这81个系数即可
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cassert>
#define mod 998244353
using namespace std;
typedef long long LL;
LL n,m;
struct Matrix{LL d[9][9];}ha,A,id;
Matrix operator *(Matrix p1,Matrix p2) {
Matrix r = id;
for (int i=0;i<9;i++)
for (int j=0;j<9;j++)
for (int k=0;k<9;k++)
r.d[i][j] += p1.d[i][k] * p2.d[k][j] , r.d[i][j] %= mod;
return r;
}
Matrix mqp(Matrix a,LL b) {
if (b == 1) return a;
if (b % 2 == 0) {
Matrix tmp = mqp(a,b/2);
return tmp * tmp;
} else {
Matrix tmp = mqp(a,b-1);
return tmp * a;
}
}
inline LL qpow(LL a,LL b) {
LL r = 1;
while (b) {
if (b&1) r = ( r * a ) % mod;
b >>= 1, a = ( a * a ) % mod;
}
return r;
}
LL trans(int x1,int y1,int x2,int y2) {
if (( x1 == x2 && x1 > 0 ) || ( y1 == y2 && y1 > 0 )) return 0LL;
if (( x1 == y1 && x1 > 0 ) || ( x2 == y2 && x2 > 0 )) return 0LL;
if (!x2 && !y2) {
if ( x1 && y1) return (m-2) * (m-3) % mod;
if ( x1 || y1) return (m-3) * (m-3) % mod;
return m>=4 ? ( (m-4) * (m-4) + (m-3) ) % mod : 1;
}
if (x2 > 0 && y2 > 0) return 1LL;
if (x2 > 0 && y2 == 0) return m-2-(y1==0);
if (y2 > 0 && x2 == 0) return m-2-(x1==0);
assert(0);
}
int main() {
cin >> n >> m;
if (n&1LL) {
LL ans = ( qpow(m-1,n) - (m-1) + mod ) % mod;
cout << ans << endl; return 0;
}
if (m == 1) { puts("0"); return 0; }
for (int i=0;i<9;i++)
for (int j=0;j<9;j++)
ha.d[i][j] = max( trans(i%3,i/3,j%3,j/3) , 0LL);
for (int i=0;i<9;i++) A.d[i][i] = 1LL;
// for (int i=0;i<9;i++) {
// for (int j=0;j<9;j++) printf("%d ",ha.d[i][j]);
// printf("\n");
// }
Matrix dhr = mqp(ha,n/2-1);
A = A * dhr;
memset(ha.d,0,sizeof(ha.d));
ha.d[0][7] = (m-1) * m % mod; ha = ha * A;
// for (int i=0;i<9;i++) {
// for (int j=0;j<9;j++) printf("%d ",ha.d[i][j]);
// printf("\n");
// }
LL ans = ( ha.d[0][0] + ha.d[0][1] + ha.d[0][6] + ha.d[0][7] + mod) % mod;
cout << ans << endl;
return 0;
}