problem
solution
骨牌横着放会占用一行两列,骨牌竖着放会占用两行一列。
问题可以抽象为:每次可以选择连续的两行放 A A A,或选一行放 B B B;每次可以选一列放 B B B ,或选连续的两列放 A A A。
且最后行的 A A A 数量等于列的 B B B 数量,列的 A A A 数量等于行的 B B B 数量。【多米诺骨牌大小的匹配】
这样一个放物品的过程会对应原来很多个多米诺骨牌摆放的方案,因为行列之间 A , B A,B A,B 配对的方案不只一种,任意一种配对方案对应一种不同的摆放方案。
行列之间的 A , B A,B A,B 的摆放是不影响的,只有数量相等的要求。
假设选了 x x x 个 A A A, y y y 个 B B B,则方案数为 x ! y ! x!y! x!y!。
由于行列互不影响,所以分开求解,最后利用乘法原理计算即可。下面以行为例。
设 d p i , j , k : dp_{i,j,k}: dpi,j,k: 考虑前 i i i 行,放了 j j j 个 A A A, k k k 个 B B B 的方案数。
暴力的状态都是
n
3
n^3
n3 的,考虑优化。
实际上,
B
B
B 的性质是只占用一行,只要知道剩下的空位,我们就能计算
B
B
B 的方案数。
改写 d p i , j : dp_{i,j}: dpi,j: 前 i i i 行放了 j j j 个 A A A ,还没有放 B B B 的方案数。
那么最后可以放 B B B 的空位为 c n t − 2 ∗ j cnt-2*j cnt−2∗j, c n t : cnt: cnt: 去除掉题目所给多米诺骨牌占用的行后剩下的行数,如果要选 k k k 个 B B B,则可以 O ( 1 ) O(1) O(1) 计算方案数为: ( c n t − 2 ∗ j k ) \binom{cnt-2*j}{k} (kcnt−2∗j)
最后合并需要枚举行选了多少个 A A A 和多少个 B B B,枚举总时间也是 O ( n 2 ) O(n^2) O(n2) 的。
所以本题最终的时间复杂度为: O ( n 2 ) O(n^2) O(n2)。
code
#include <cstdio>
#define int long long
#define mod 998244353
#define maxn 3605
int n, m, k, cnt_c, cnt_r;
int fac[maxn];
bool row[maxn], col[maxn];
int c[maxn][maxn], f[maxn][maxn], g[maxn][maxn];
signed main() {
scanf( "%lld %lld %lld", &n, &m, &k );
cnt_r = n, cnt_c = m;
for( int i = 1, X1, Y1, X2, Y2;i <= k;i ++ ) {
scanf( "%lld %lld %lld %lld", &X1, &Y1, &X2, &Y2 );
row[X1] = row[X2] = col[Y1] = col[Y2] = 1;
cnt_r --, cnt_c --;
if( X1 ^ X2 ) cnt_r --;
else cnt_c --;
}
fac[0] = f[0][0] = g[0][0] = c[0][0] = 1;
for( int i = 1;i < maxn;i ++ ) fac[i] = fac[i - 1] * i % mod;
for( int i = 1;i < maxn;i ++ ) {
c[i][0] = c[i][i] = 1;
for( int j = 1;j < i;j ++ )
c[i][j] = ( c[i - 1][j] + c[i - 1][j - 1] ) % mod;
}
for( int i = 1;i <= n;i ++ )
for( int j = 0;j <= i / 2;j ++ ) {
f[i][j] = f[i - 1][j];
if( i >= 2 and j and ! row[i] and ! row[i - 1] )
f[i][j] = ( f[i][j] + f[i - 2][j - 1] ) % mod;
}
for( int i = 1;i <= m;i ++ )
for( int j = 0;j <= i / 2;j ++ ) {
g[i][j] = g[i - 1][j];
if( i >= 2 and j and ! col[i] and ! col[i - 1] )
g[i][j] = ( g[i][j] + g[i - 2][j - 1] ) % mod;
}
int ans = 0;
for( int i = 0;i <= cnt_r / 2;i ++ )
for( int j = 0;j <= cnt_c / 2;j ++ )
if( i * 2 + j <= cnt_r and i + j * 2 <= cnt_c )
ans = ( ans + f[n][i] * g[m][j] % mod * c[cnt_r - 2 * i][j] % mod * c[cnt_c - 2 * j][i] % mod * fac[i] % mod * fac[j] ) % mod;
printf( "%lld\n", ans );
return 0;
}