[ABC307E] Distinct Adjacent
题面翻译
给定一个长度为 n n n 的环,每个位置可以填 1 ∼ m 1\sim m 1∼m的颜色,求有多少种方案,满足相邻位置颜色不相同。
translated by @liangbowen.
题目描述
$ 1 $ から $ N $ の番号がついた $ N $ 人の人が輪になってならんでいます。人 $ 1 $ の右隣には人 $ 2 $ が、人 $ 2 $ の右隣には人 $ 3 $ が、……、人 $ N $ の右隣には人 $ 1 $ がいます。
$ N $ 人の人にそれぞれ $ 0 $ 以上 $ M $ 未満の整数を $ 1 $ つずつ渡します。
$ M^N $ 通りの渡し方のうち、どの隣り合う $ 2 $ 人が渡された数も異なるものの数を、$ 998244353 $ で割ったあまりを求めてください。
输入格式
入力は以下の形式で標準入力から与えられる。
$ N $ $ M $
输出格式
答えを出力せよ。
样例 #1
样例输入 #1
3 3
样例输出 #1
6
样例 #2
样例输入 #2
4 2
样例输出 #2
2
样例 #3
样例输入 #3
987654 456789
样例输出 #3
778634319
提示
制約
- $ 2\ \leq\ N,M\ \leq\ 10^6 $
- $ N,M $ は整数である
Sample Explanation 1
人 $ 1,2,3 $ に渡す整数がそれぞれ $ (0,1,2),(0,2,1),(1,0,2),(1,2,0),(2,0,1),(2,1,0) $ のときの $ 6 $ 通りです。
Sample Explanation 2
人 $ 1,2,3,4 $ に渡す整数がそれぞれ $ (0,1,0,1),(1,0,1,0) $ のときの $ 2 $ 通りです。
Sample Explanation 3
$ 998244353 $ で割ったあまりを求めてください。
首先思考如果题目是在一个链中而不是一个环中如何计算,对于这个纯数学题我们易得
m
×
(
m
−
1
)
(
n
−
1
)
m\times \left ( m-1 \right ) ^{\left ( n-1 \right )}
m×(m−1)(n−1)
那么我们可不可以在链的结尾再加上开头求解呢?
于是我们得到了这样的式子
m
×
(
m
−
1
)
(
n
−
2
)
×
(
m
−
2
)
m\times \left ( m-1 \right ) ^{\left ( n-2 \right )}\times\left ( m-2 \right )
m×(m−1)(n−2)×(m−2)
之所以把最后一个元素的可能性改为m-2,是因为他还要避免与第一个元素的颜色相同.
于是我们就得到了这样的代码
#include<bits/stdc++.h>
using namespace std;
const long long Mod=998244353;
long long n,m;
long long pow_mod(long long a, long long b){
long long ret = 1;
while(b){
if(b & 1) ret = (ret * a) % Mod;
a = (a * a) % Mod;
b >>= 1;
}
return ret;
}
int main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
cout<<m%Mod*pow_mod(m-1,n-2)%Mod*(m-2)%Mod;
return 0;
}
吗?
如果倒数第二个元素和第一个元素的颜色相同,那么答案就是
m
×
(
m
−
1
)
(
n
−
1
)
m\times \left ( m-1 \right ) ^{\left ( n-1 \right )}
m×(m−1)(n−1)
所以我们要把这两种情况加起来
但是我们又无法保证n和m的给出可以使倒数第二个元素和第一个元素的颜色相同.
为了验证将他们加起来的猜想我又写了下面的程序比较n,m均<=10时,正确答案和相加的结果
#include <bits/stdc++.h>
using namespace std;
const int N = 2, NM = 1e6 + 10, Mod = 998244353;
struct Matrix
{
int a[N + 2][N + 2];
Matrix()
{
memset(a, 0, sizeof(a));
}
void init_I()
{
for (int i = 1; i <= N; i++)
{
a[i][i] = 1;
}
}
Matrix operator*(Matrix B) const
{
Matrix res;
for (int i = 1; i <= N; i++)
for (int j = 1; j <= N; j++)
for (int k = 1; k <= N; k++)
{
int x = 1ll * a[i][k] * B.a[k][j] % Mod;
(res.a[i][j] += x) %= Mod;
}
return res;
}
};
Matrix ksm(Matrix x, long long y)
{
Matrix res;
res.init_I();
while (y)
{
if (y & 1)
res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
/// @brief
/// @return
long long pow_mod(long long a, long long b){
long long ret = 1;
while(b){
if(b & 1) ret = (ret * a) % Mod;
a = (a * a) % Mod;
b >>= 1;
}
return ret;
}
int main()
{
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
for (long long n = 2; n <= 10; n++)
{
for (long long m = 2; m <= 10; m++)
{
Matrix A, B, R;
A.a[1][2] = 1;
A.a[2][1] = (m - 1) % Mod;
A.a[2][2] = (m - 2) % Mod;
B.a[1][1] = 1;
R = ksm(A, n - 1) * B;
cout << n << " " << m << "--->" << R.a[2][1] % Mod * m % Mod << "; ";
///
cout<<m%Mod*pow_mod(m-1,n-2)%Mod*(m-2)%Mod+m%Mod*pow_mod(m-1,n-2)%Mod*(m-1)%Mod<<endl;
}
}
return 0;
}
输出
2 2--->2; 2
2 3--->6; 9
2 4--->12; 20
2 5--->20; 35
2 6--->30; 54
2 7--->42; 77
2 8--->56; 104
2 9--->72; 135
2 10--->90; 170
3 2--->0; 2
3 3--->6; 18
3 4--->24; 60
3 5--->60; 140
3 6--->120; 270
3 7--->210; 462
3 8--->336; 728
3 9--->504; 1080
3 10--->720; 1530
4 2--->2; 2
4 3--->18; 36
4 4--->84; 180
4 5--->260; 560
4 6--->630; 1350
4 7--->1302; 2772
4 8--->2408; 5096
4 9--->4104; 8640
4 10--->6570; 13770
5 2--->0; 2
5 3--->30; 72
5 4--->240; 540
5 5--->1020; 2240
5 6--->3120; 6750
5 7--->7770; 16632
5 8--->16800; 35672
5 9--->32760; 69120
5 10--->59040; 123930
6 2--->2; 2
6 3--->66; 144
6 4--->732; 1620
6 5--->4100; 8960
6 6--->15630; 33750
6 7--->46662; 99792
6 8--->117656; 249704
6 9--->262152; 552960
6 10--->531450; 1115370
7 2--->0; 2
7 3--->126; 288
7 4--->2184; 4860
7 5--->16380; 35840
7 6--->78120; 168750
7 7--->279930; 598752
7 8--->823536; 1747928
7 9--->2097144; 4423680
7 10--->4782960; 10038330
8 2--->2; 2
8 3--->258; 576
8 4--->6564; 14580
8 5--->65540; 143360
8 6--->390630; 843750
8 7--->1679622; 3592512
8 8--->5764808; 12235496
8 9--->16777224; 35389440
8 10--->43046730; 90344970
9 2--->0; 2
9 3--->510; 1152
9 4--->19680; 43740
9 5--->262140; 573440
9 6--->1953120; 4218750
9 7--->10077690; 21555072
9 8--->40353600; 85648472
9 9--->134217720; 283115520
9 10--->387420480; 813104730
10 2--->2; 2
10 3--->1026; 2304
10 4--->59052; 131220
10 5--->1048580; 2293760
10 6--->9765630; 21093750
10 7--->60466182; 129330432
10 8--->282475256; 599539304
10 9--->75497479; 268435454
10 10--->492051351; 1328476452
综上所述,O(1)和O(lgn)复杂度的算法不足以完成此题,接下来使用DP
首先确定状态,只有确定了状态和转移,DP才有可能成功.
假设f[i][j]表示前i个元素,第i个元素为j颜色时的方案数;
这种状态其实是可以确定转移的,but
N,M<=1e6
也就是说f数组(dp数组)的大小要到1e12,显然是不能接受的.
现在回想前面的推理,失败的原因是不是
我们又无法保证n和m的给出可以使倒数第二个元素和第一个元素的颜色相同
那我们由此获得启发,假设f[i][1或0]表示表示前i个元素,第i个元素的颜色与第一个元素的颜色 相同/不相同 时的方案数.
我们又能轻松 得到
f
[
i
]
[
1
]
=
f
[
i
−
1
]
[
0
]
f[i][1]=f[i-1][0]
f[i][1]=f[i−1][0]
f
[
i
]
[
0
]
=
f
[
i
−
1
]
[
1
]
×
(
m
−
1
)
+
f
[
i
−
1
]
[
0
]
×
(
m
−
2
)
f[i][0]=f[i-1][1]\times \left ( m-1 \right ) + f[i-1][0]\times \left ( m-2 \right )
f[i][0]=f[i−1][1]×(m−1)+f[i−1][0]×(m−2)
且
f
[
1
]
[
1
]
=
1
f[1][1]=1
f[1][1]=1
因为第一个元素的颜色不可能和第一个元素的颜色不同,所以
f
[
1
]
[
0
]
=
0
f[1][0]=0
f[1][0]=0
得到代码:
#include <bits/stdc++.h>
using namespace std;
const int MAXN = 1e6 + 10;
const int Mod = 998244353;
int n, m;
int f[MAXN][2];
int main() {
cin >> n >> m;
f[1][1] = m;
for (int i = 2; i <= n; i++) {
f[i][0] = (1ll * f[i - 1][1] * (m - 1) + 1ll * f[i - 1][0] * (m - 2)) % Mod;
f[i][1] = f[i - 1][0];
}
int ans = f[n][0];
printf("%d\n", ans);
return 0;
}
以下是题外的优化
事情到这里本应该结束了,刚才的代码24个点共耗时20ms,那有没有更好的方法既能加速又能提高n和m的上限
观察下面这两个式子(刚才的转移方程)
f
[
i
]
[
1
]
=
f
[
i
−
1
]
[
0
]
f[i][1]=f[i-1][0]
f[i][1]=f[i−1][0]
f
[
i
]
[
0
]
=
f
[
i
−
1
]
[
1
]
×
(
m
−
1
)
+
f
[
i
−
1
]
[
0
]
×
(
m
−
2
)
f[i][0]=f[i-1][1]\times \left ( m-1 \right ) + f[i-1][0]\times \left ( m-2 \right )
f[i][0]=f[i−1][1]×(m−1)+f[i−1][0]×(m−2)
是不是很眼熟
这还可以用矩阵优化啊
那就是找一个矩阵X,使得
X
×
[
f
[
i
]
[
1
]
f
[
i
]
[
0
]
]
=
[
f
[
i
+
1
]
[
1
]
f
[
i
+
1
]
[
0
]
]
X\times \begin{bmatrix} f[i][1]\\f[i][0] \end{bmatrix}= \begin{bmatrix} f[i+1][1] \\f[i+1][0] \end{bmatrix}
X×[f[i][1]f[i][0]]=[f[i+1][1]f[i+1][0]]
那根据转移方程就可以得到
X
=
[
0
1
m
−
1
m
−
2
]
X=\begin{bmatrix} 0 & 1\\ m-1 &m-2 \end{bmatrix}
X=[0m−11m−2]
也就是
[
0
1
m
−
1
m
−
2
]
×
[
f
[
i
]
[
1
]
f
[
i
]
[
0
]
]
=
[
f
[
i
+
1
]
[
1
]
f
[
i
+
1
]
[
0
]
]
\begin{bmatrix} 0 & 1\\ m-1 &m-2 \end{bmatrix}\times \begin{bmatrix} f[i][1]\\f[i][0] \end{bmatrix}= \begin{bmatrix} f[i+1][1] \\f[i+1][0] \end{bmatrix}
[0m−11m−2]×[f[i][1]f[i][0]]=[f[i+1][1]f[i+1][0]]
再加上取余运算之类的小细节就得到了24个点总耗时5ms的空间时间都优化过的程序
#include<cctype>
#include<cerrno>
#include<cfloat>
#include<climits>
#include<clocale>
#include<cmath>
#include<csetjmp>
#include<csignal>
#include<cstdarg>
#include<cstddef>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<ctime>
#include<ccomplex>
#include<cfenv>
#include<cinttypes>
#include<cstdbool>
#include<cstdint>
#include<ctgmath>
#include<cwchar>
#include<cwctype>
#include<algorithm>
#include<bitset>
#include<complex>
#include<deque>
#include<exception>
#include<fstream>
#include<functional>
#include<iomanip>
#include<ios>
#include<iosfwd>
#include<iostream>
#include<istream>
#include<iterator>
#include<limits>
#include<list>
#include<locale>
#include<map>
#include<memory>
#include<new>
#include<numeric>
#include<ostream>
#include<queue>
#include<set>
#include<sstream>
#include<stack>
#include<stdexcept>
#include<streambuf>
#include<string>
#include<typeinfo>
#include<utility>
#include<valarray>
#include<vector>
#include<array>
#include<atomic>
#include<chrono>
#include<condition_variable>
#include<forward_list>
#include<future>
#include<initializer_list>
#include<mutex>
#include<random>
#include<ratio>
#include<regex>
#include<scoped_allocator>
#include<system_error>
#include<thread>
#include<tuple>
#include<typeindex>
#include<type_traits>
#include<unordered_map>
#include<unordered_set>
using namespace std;
const int N = 2,NM=1e6+10,Mod=998244353;
struct Matrix
{
int a[N + 2][N + 2];
Matrix()
{
memset(a, 0, sizeof(a));
}
void init_I()
{
for (int i = 1; i <= N; i++)
{
a[i][i] = 1;
}
}
Matrix operator*(Matrix B) const
{
Matrix res;
for (int i = 1; i <= N; i++)
for (int j = 1; j <= N; j++)
for (int k = 1; k <= N; k++)
{
int x = 1ll * a[i][k] * B.a[k][j] % Mod;
(res.a[i][j] += x) %= Mod;
}
return res;
}
};
Matrix ksm(Matrix x, long long y)
{
Matrix res;
res.init_I();
while (y)
{
if (y & 1)
res = res * x;
x = x * x;
y >>= 1;
}
return res;
}
long long n,m;
int main(){
ios::sync_with_stdio(0);
cin.tie(0);
cout.tie(0);
cin>>n>>m;
Matrix A,B,R;
A.a[1][2]=1;A.a[2][1]=(m-1)%Mod;A.a[2][2]=(m-2)%Mod;B.a[1][1]=1;
//A.a[2][1]=1;A.a[1][2]=m-1;A.a[2][2]=m-2;B.a[1][1]=1;
//A.a[0][1]=m-1;A.a[1][0]=1;A.a[1][1]=m-2;B.a[1][0]=1;
R=ksm(A,n-1)*B;
cout<<R.a[2][1]%Mod*m%Mod;
//cout<<R.a[0][1]%Mod;
return 0;
}