一 原题
On one of the planets of Solar system, in Atmosphere University, many students are fans of bingo game.
It is well known that one month on this planet consists of n2n2 days, so calendars, represented as square matrix nn by nn are extremely popular.
Weather conditions are even more unusual. Due to the unique composition of the atmosphere, when interacting with sunlight, every day sky takes one of three colors: blue, green or red.
To play the bingo, you need to observe the sky for one month — after each day, its cell is painted with the color of the sky in that day, that is, blue, green or red.
At the end of the month, students examine the calendar. If at least one row or column contains only cells of one color, that month is called lucky.
Let's call two colorings of calendar different, if at least one cell has different colors in them. It is easy to see that there are 3n⋅n3n⋅n different colorings. How much of them are lucky? Since this number can be quite large, print it modulo 998244353998244353.
The first and only line of input contains a single integer nn (1≤n≤10000001≤n≤1000000) — the number of rows and columns in the calendar.
Print one number — number of lucky colorings of the calendar modulo 998244353998244353
1
3
2
63
3
9933
In the first sample any coloring is lucky, since the only column contains cells of only one color.
In the second sample, there are a lot of lucky colorings, in particular, the following colorings are lucky:
While these colorings are not lucky:
二 分析
给n*n个格子,每个格子可以填3种颜色。问有多少种填色方案,至少有一列或一行是同色的。根据容斥原理,答案应该等于:
其中f(i, j)为前i行和前j列每行/列同色的方案数:i, j中有一个为0的时候:;否则
对于前一种情况,我们O(n)的把每个f(i,j)算出来,求出和式里对应的即可。对于后一种情况,和式里有两个循环变量,直接求和的复杂度是O(n^2)的,要做一些数学推导让复杂度降到O(n):
和式中的组合数下标都为n,O(n)预处理一下,所有的幂次都用快速幂求,注意处理底数为负的情况。
三 代码
/*
AUTHOR: maxkibble
LANG: c++
PROB: cf 997C
*/
#include <cstdio>
typedef long long LL;
const int maxn = 1e6 + 5;
const int mod = 998244353;
LL n, c[maxn], ans;
LL fastPow(LL a, LL x) {
LL ret = 1, base = a;
while (x) {
if (x & 1) ret = ret * base % mod;
x >>= 1;
base = (base * base) % mod;
}
return (ret % mod + mod) % mod;
}
int main() {
scanf("%lld", &n);
c[0] = 1;
for (int i = 1; i <= n; i++) {
c[i] = c[i - 1] * (n - i + 1) % mod * fastPow(i, mod - 2) % mod;
}
for (int i = 0; i < n; i++) {
LL s = fastPow(3, i);
LL t1 = fastPow(1 - s, n), t2 = fastPow(-s, n);
LL d = (c[i] * (t1 - t2) % mod + mod) % mod;
if (i & 1) ans += d;
else ans -= d;
ans = (ans % mod + mod) % mod;
}
ans = (ans * 3) % mod;
for (int i = 1; i <= n; i++) {
LL d = fastPow(3, i) * fastPow(3, n * (n - i)) % mod;
d = c[i] * d % mod;
if (i & 1) ans += 2 * d;
else ans -= 2 * d;
ans = (ans % mod + mod) % mod;
}
printf("%lld\n", ans);
return 0;
}