奇数的很显然是
(m−1)n−(−1)n(m−1)
(
m
−
1
)
n
−
(
−
1
)
n
(
m
−
1
)
。
对于偶数的情况,假设先不考虑对称不同的限制,我们可以DP的时候只需要关心当前为是否与第一位相同。那么考虑到对称不同的限制,我们可以两个两个填(也就是
i
i
与一起填),那么我们只要关心当前的两个位置和
1,n2+1
1
,
n
2
+
1
两个位置的异同情况,把转移矩阵搞出来快速幂即可。
具体地,设第
1
1
位填了,第
2
2
位填了,那么当前填两位一共有
7
7
种状态如下:
1:(≠,a)
1
:
(
≠
,
a
)
2:(≠,b)
2
:
(
≠
,
b
)
3:(a,≠)
3
:
(
a
,
≠
)
4:(b,≠)
4
:
(
b
,
≠
)
5:(a,b)
5
:
(
a
,
b
)
6:(b,a)
6
:
(
b
,
a
)
其中
≠
≠
表示
≠a
≠
a
且
≠b
≠
b
,那么转移矩阵就是:
其中第
i
i
行第列(从
0
0
开始)就是上一位状态转移到当前位
j
j
的系数。答案最后乘上(
a,b
a
,
b
的选取方案)即可。
代码:
#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define up(x,y) (x=(x+(y))%mod)
using namespace std;
const int mod=998244353;
ll n,m;
ll ksm(ll a,ll b)
{
ll r=1;
for(;b;b>>=1,a=a*a%mod)
if(b&1) r=r*a%mod;
return r;
}
struct matrix
{
int h,w;
ll a[7][7];
matrix(int xh,int xw){h=xh;w=xw;memset(a,0,sizeof(a));}
matrix operator *(matrix b)
{
matrix re(h,b.w);
for(int i=0;i<h;i++)
for(int j=0;j<b.w;j++)
for(int k=0;k<w;k++)
up(re.a[i][j],a[i][k]*b.a[k][j]);
return re;
}
void init()
{
a[0][0]=((m-3)*(m-3)-(m-4))%mod;
a[0][1]=m-3;
a[0][2]=m-3;
a[0][3]=m-3;
a[0][4]=m-3;
a[0][5]=1;
a[0][6]=1;
a[1][0]=((m-3)*(m-2)-(m-3))%mod;
a[1][1]=0;
a[1][2]=m-3;
a[1][3]=m-2;
a[1][4]=m-2;
a[1][5]=1;
a[1][6]=0;
a[2][0]=((m-3)*(m-2)-(m-3))%mod;
a[2][1]=m-3;
a[2][2]=0;
a[2][3]=m-2;
a[2][4]=m-2;
a[2][5]=0;
a[2][6]=1;
a[3][0]=((m-3)*(m-2)-(m-3))%mod;
a[3][1]=m-2;
a[3][2]=m-2;
a[3][3]=0;
a[3][4]=m-3;
a[3][5]=0;
a[3][6]=1;
a[4][0]=((m-3)*(m-2)-(m-3))%mod;
a[4][1]=m-2;
a[4][2]=m-2;
a[4][3]=m-3;
a[4][4]=0;
a[4][5]=1;
a[4][6]=0;
a[5][0]=((m-2)*(m-2)-(m-2))%mod;
a[5][1]=m-2;
a[5][2]=0;
a[5][3]=0;
a[5][4]=m-2;
a[5][5]=0;
a[5][6]=1;
a[6][0]=((m-2)*(m-2)-(m-2))%mod;
a[6][1]=0;
a[6][2]=m-2;
a[6][3]=m-2;
a[6][4]=0;
a[6][5]=1;
a[6][6]=0;
}
}T(7,7);
matrix matksm(matrix a,ll b)
{
matrix r(a.h,a.w);
for(int i=0;i<r.h;i++)
r.a[i][i]=1;
for(;b;b>>=1,a=a*a)
if(b&1) r=r*a;
return r;
}
int main()
{
scanf("%lld%lld",&n,&m);
if(n&1) {printf("%lld",(ksm(m-1,n)+(m-1)*((n&1)?-1:1)+mod)%mod);return 0;}
T.init();
T=matksm(T,(n>>1)-1);
ll ans=(T.a[5][0]+T.a[5][2]+T.a[5][3]+T.a[5][5])%mod;
printf("%lld",ans*(m-1)%mod*m%mod);
return 0;
}