环长为奇数直接dp,偶数用两行dp(相当于把环前一半和后一半拼起来)
(然而实际上奇数有公式QwQ)
#include<stdio.h>
#include<cstring>
#include<algorithm>
#define mod 998244353
using namespace std;
typedef long long ll;
struct matrix{int k[9][9];}a,p;
matrix operator * (matrix a,matrix b)
{
matrix rt;
memset(rt.k,0,sizeof(rt.k));
for (int i=0;i<9;i++) for (int k=0;k<9;k++) for (int j=0;j<9;j++) rt.k[i][j]=(rt.k[i][j]+(ll)a.k[i][k]*b.k[k][j]) % mod;
return rt;
}
int n,m;
int calc(int x1,int y1,int x2,int y2)
{
if ((x1==x2 || x2==y2) && x2 || y1==y2 && y1) return 0;
if (!x2 && !y2)
{
if (x1 && y1) return (ll)(m-3)*(m-2) % mod;
if (x1 || y1) return (ll)(m-3)*(m-3) % mod;
return (m>=4)?((m-4)*2+(ll)(m-4)*(m-5)+1) % mod:1;
}
if (x2 && y2) return 1;
if (x2) return m-2-(y1==0);
return m-2-(x1==0);
}
inline int pow(int a,int b)
{
int rt=1;
for (;b;b>>=1,a=(ll)a*a % mod) if (b&1) rt=(ll)rt*a % mod;
return rt;
}
int main()
{
scanf("%d%d",&n,&m);
if (n&1){printf("%d",(pow(m-1,n)-m+1+mod) % mod);return 0;}n=n/2-1;
for (int i=0;i<9;i++) for (int j=0;j<9;j++) p.k[i][j]=max(calc(i%3,i/3,j%3,j/3),0);
for (int i=0;i<9;i++) a.k[i][i]=1;
for (;n;n>>=1,p=p*p) if (n&1) a=a*p;
memset(p.k,0,sizeof(p.k));
p.k[0][7]=(ll)(m-1)*m % mod;p=p*a;
printf("%d",((ll)p.k[0][0]+p.k[0][1]+p.k[0][6]+p.k[0][7]) % mod);
}