题解/算法 {LCP 76. 魔法棋盘}
LINK: https://leetcode.cn/problems/1ybDKD/
;
注意, 他说 两个异色棋子之间 如果只有1个棋子, 那么这个序列就是非法的; 请仔细理解这句话, 你可能觉得 只有R {R/B} B
这种情况 是非法的, 其实不是的, R ... {R/B} ... B
(...
指若干个空地) 这其实也是非法的, 也就是 他俩之间只有1个棋子 并不意味着 他俩之间就1个位置 不是这样的, 而是他俩之间的所有位置上 只有1个棋子 而其他都是空位;
@DELI;
序列: 由{.
(空位), R, B}组成的序列
合法序列: 将序列中所有的.
给删除掉后, 该序列必须属于ST7
中的一种;
合法序列的ST7值: [0: 无R和B; 1: R; 2: B; 3: RR…; 4: BB…; 5: …RB; 6: …BR];
.
注意, 不要把1,3
情况 混为一谈, 否则 如果我们令st
同时表示1,3
情况, 则此时再加一个B
, 他可能变成RB
(合法), 也可能变成RRB
(非法), 这就分不清了;
合法矩阵: 每个{行/列} 都是合法序列;
@DELI;
联想到状压DP, 通常对于M*N
的二维网格, 状压DP以单个行为单位, 即DP(i,?)
表示的是
i
∗
N
i*N
i∗N的合法矩阵A, 然后往他的下面 再添加一行Row, 就变成了
(
i
+
1
)
∗
N
(i+1)*N
(i+1)∗N的矩阵B;
.
也就是, A + Row -> B
, 但是注意 A一定是合法矩阵, 但B不一定是合法的 (比如, A的某列是RR...
, 而Row的这一列是B
棋子, 那么新组成的B矩阵 就不是合法的)
确定DP(i,?)
中的?
的定义 非常重要;
我们有必要将?
定义为 第i行的行状态吗? 其实完全没有必要的, 因为状压DP此时是以每一行为单位, 即行与行之间是独立的; DP的转移 根本不会用到 上一行的行状态;
.
比如上面的A + Row -> B
, 如果我们只考虑让B
矩阵的所有行 都是合法的 (暂不考虑让B的列是否合法), 那么 你只要保证Row
这个行是合法的, 那么B的所有行也会是合法的;
也就是, 你只需要保证: DP的转移 要使得B的所有列是合法的; 因此?
需要记录A矩阵的所有列的状态 (即ST7), 即 他是个长度为N
的7进制数;
元素的ST3值: [0: 空地; 1: R; 2: B], 让Row为长度为N
的3进制数(ST3) 且要保证 该行 是合法序列;
用Trans_7_1[ 7][ 3]
来表示: a = ?[b][c]
一个序列的ST7状态为b
往序列最后添加一个ST3为c
的元素后, 新序列的ST7值 (如果非法, 则为-1
);
那么, 对于DP(i, j)
(注意, j
是个长度为N的7进制数), 令vector<int> R
为 第i+1
行 所有合法的ST3值 (即长度为N的3进制数);
auto cur = DP(i,j);
for( auto row : R){
auto merge = Trans_7_3[ j][ row];
if( merge == -1) continue;
DP( i + 1, merge) += cur;
}
c = Trans_7_3[ a][ b]
的值为: 令a = [a1, a2,..., aN]
(长度为N的7进制数, 左侧是低位), 同理b = [b1, ..., bN]
(3进制数), 则c = [c1, ..., cN]
其中ci = Trans_7_1[ ai][ bi]
一旦ci = -1
则c = -1
(这是个特判);
@DELI;
由于DP的定义 是用N个列进行的, 因此我们要让N尽量的小 (如果N > M
则让矩阵转置);
DP转移是:
M
∗
7
N
∗
3
N
∗
?
M * 7^N * 3^N * ?
M∗7N∗3N∗? (?
获得新的矩阵的列状态的时间), 令M=6, N=5
此时已经达到1e7
的水平, 因此?
必须是常数, 即我们预处理这个Trans_7_3
数组;
这个R(即i+1
行的所有合法的行状态ST3) 的获得, 你可能会通过暴力的方式, 即遍历
3
N
3^N
3N的所有状态st 如果他可以与Map
(这行的地图)匹配, 然后看这个状态 他的ST7状态 是否合法 (即去掉st这个3进制的所有0
, 他变成一个由1,2
组成的序列, 看这个序列是否符合ST7);
.
这太暴力了, 上面默认他是3^N
而如果你用这种暴力方式 比3^N
还高, 会超时的;
.
以前讲过这种优化 即用DFS剪枝优化, 而此时正好是可以的, 因为 如果当前是RBB
他已经非法了 就没必要再往下遍历了, 因此用DFS( col, st3, st7)
表示 遍历到一个序列的col
位置, 前[0....,col-1]
表示了一个3进制数为st3
, 且这个st3
所对应的ST7值 是st7
(他用来判断 序列是否合法);
对于预处理Trans_7_3
他是时间 也是非常大的!
{ // Trans_7_3
for( int col_st7 = 0; col_st7 < MAX_ST7_5; ++col_st7){
for( int row_st3 = 0; row_st3 < MAX_ST3_5; ++row_st3){
auto & cur = Trans_7_3[ col_st7][ row_st3];
cur = 0;
for( int i = 0; i < 5; ++i){
auto merge = Trans_7_1[ Tools::Get_radix_bit( col_st7, i, Pow7)][ Tools::Get_radix_bit( row_st3, i, Pow3)];
if( merge == -1){ cur = -1; break;}
Tools::Set_radix_bit( cur, i, merge, Pow7);
}
}
}
}
优化1: 对于M*N
的矩阵, Trans_7_3[ a][ b]
表示: a
是长度N的7进制数, b
是长度为N的3进制数, 显然N在这里是个变量; 因此你会认为 对于每组数据 都进行预处理;
.
其实, 我们就统一按照 N=5
的标准 进行预处理, 就可以了 (这个优化确实很难…), 因为 如果当前的N=3 != 5
, 那么他用的st7
应该是长度为3的7进制数 但实际上他是长度为5的7进制数, 有什么影响吗? 他的2个高位(按理说 他不应该存在) 是0
, 而st7
和st3
这多余的2个高位 都是0, 他俩merge 结果 也是0;
还有个更大的优化, 可以把里面的for(5)
给去掉, 即我们以st7, st3
的最后一位 做拆分, st7: [a1,...,a5], st3: [b1,...,b5]
, 拆分成 A= [a2...,a5], B= [b2,...,b5]
与 a1, b1
两个部分, a1,b1
通过Trans_7_1[ st7 % 7][ st3 % 3]
获得, A,B
的匹配 通过Trans_7_3[ st7/ 7][ st3/ 3]
(这非常重要, 即此时a2...a5
会往低位移动1位, 即a2
会变成最低位(第1位) 然后最高位(第5位)会补上0, 但因为merge( 0, 0) = 0
(这在本题是成立的), 因此他的结果 你还需要将他往高位移动1 即*= 7
;
auto & cur = Trans_7_3[ col_st7][ row_st3];
if( col_st7 == 0 && row_st3 == 0){ cur = 0; continue;}
auto merge = Trans_7_1[ col_st7 % 7][ row_st3 % 3];
auto hig = Trans_7_3[ col_st7 / 7][ row_st3 / 3];
if( merge == -1 || hig == -1) cur = -1;
else cur = hig * 7 + merge;
@DELI;
代码
constexpr int MAX_ST7_5 = 7 * 7 * 7 * 7 * 7, MAX_ST3_5 = 3 * 3 * 3 * 3 * 3;
// 序列: 由{., R, B}组成的序列
// 合法序列: 将序列中所有的`.`给删除掉后, 该序列必须属于`ST7`中的一种;
// 合法序列的ST7值: [0: 无R和B, 1: R, 2: B, 3: RR..., 4: BB..., 5: ...RB, 6: ...BR];
// 元素的ST3值: [0: 空地, 1: R, 2: B];
// 序列的ST3值: 令序列长度为N, 则每个元素的ST3值 所组成的长度为N的3进制数, 即为该序列的ST3值;
// 合法矩阵: 每个{行/列} 都是合法序列;
// 矩阵的列状态: `i*N`的矩阵 每个列的ST7值分别为`[c1,...,cN]`, 则`B`的7进制数为`[cN,...,c1]`(左侧为高位), 往该`i*N`矩阵的下面 再添加一行 该行的ST3状态为`[cc1,...,ccN]`, cci=[0,3), 则`C`的3进制数为`[ccN,...,cc1]` (左侧为高位), 此时新的`(i+1)*N`矩阵的列状态7进制数为`A`;
int Trans_7_1[ 7][ 3] = {
{0, 1, 2}, // x [x/ R/ B]
{1, 3, 5}, // R [x/ R/ B]
{2, 6, 4}, // B [x/ R/ B]
{3, 3, -1}, // RR... [x/ R/ B]
{4, -1, 4}, // BB... [x/ R/ B]
{5, 6, -1}, // ...RB [x/ R/ B]
{6, -1, 5}, // ...BR [x/ R/ B]
};
int Pow7[ 6], Pow3[ 6];
// `A = ?[B][C]`: ST7值为`B`的合法序列, 在其末尾添加一个(ST3值为`C`的元素)后, 新序列的ST7值;
int Trans_7_3[ MAX_ST7_5][ MAX_ST3_5]; // `A = ?[B][C]`: 对于一个`(M+1)*N`的矩阵, 其上面的`M*N`子矩阵 且其列状态为`B`, 第`M+1`行的ST3值为`C`
long long DP[ 31][ MAX_ST7_5];
void __Initialize(){
Pow3[ 0] = 1;
for( int i = 1; i <= 5; ++i) Pow3[ i] = Pow3[ i - 1] * 3;
Pow7[ 0] = 1;
for( int i = 1; i <= 5; ++i) Pow7[ i] = Pow7[ i - 1] * 7;
{ // Trans_7_3
//> 这非常非常重要, 是最大的难点, 就预处理`N=5`的情况, 不需要根据不同的N值 来预处理; 否则会超时;
for( int col_st7 = 0; col_st7 < MAX_ST7_5; ++col_st7){
for( int row_st3 = 0; row_st3 < MAX_ST3_5; ++row_st3){
// auto & cur = Trans_7_3[ col_st7][ row_st3];
// if( col_st7 == 0 && row_st3 == 0){ cur = 0; continue;}
// auto merge = Trans_7_1[ col_st7 % 7][ row_st3 % 3];
// auto hig = Trans_7_3[ col_st7 / 7][ row_st3 / 3];
// if( merge == -1 || hig == -1) cur = -1;
// else cur = hig * 7 + merge;
auto & cur = Trans_7_3[ col_st7][ row_st3];
cur = 0;
for( int i = 0; i < 5; ++i){
auto merge = Trans_7_1[ Tools::Get_radix_bit( col_st7, i, Pow7)][ Tools::Get_radix_bit( row_st3, i, Pow3)];
if( merge == -1){ cur = -1; break;}
Tools::Set_radix_bit( cur, i, merge, Pow7);
}
}
}
}
}
class Solution {
public:
long long getSchemeCount(int M, int N, vector<string>& C) {
//> 这非常重要, 否则会超时;
{ static bool __is_first = true; if( __is_first){ __is_first = false; __Initialize();} }
if( M < N){ // 将`C`转置
swap( M, N);
auto temp = C;
C.resize( M);
for( int i = 0; i < M; ++i){
C[ i].resize( N);
for( int j = 0; j < N; ++j){
C[ i][ j] = temp[ j][ i];
}
}
}
{ // DP
int dfs_row;
vector< int> dfs_valid_row_st3;
function<void(int,int,int)> Dfs = [&]( int _col, int _row_st3, int _row_st7){
//< `row_st7: [-1,0,...,6]`;
if( _row_st7 == -1) return;
if( _col == N){
dfs_valid_row_st3.push_back( _row_st3);
return;
}
auto cur = C[ dfs_row][ _col];
if( cur == '.' || cur == '?'){
int id = 0;
auto nex_row_st3 = _row_st3;
Tools::Set_radix_bit( nex_row_st3, _col, id, Pow3);
Dfs( _col + 1, nex_row_st3, Trans_7_1[ _row_st7][ id]);
}
if( cur == 'R' || cur == '?'){
int id = 1;
auto nex_row_st3 = _row_st3;
Tools::Set_radix_bit( nex_row_st3, _col, id, Pow3);
Dfs( _col + 1, nex_row_st3, Trans_7_1[ _row_st7][ id]);
}
if( cur == 'B' || cur == '?'){
int id = 2;
auto nex_row_st3 = _row_st3;
Tools::Set_radix_bit( nex_row_st3, _col, id, Pow3);
Dfs( _col + 1, nex_row_st3, Trans_7_1[ _row_st7][ id]);
}
};
{ // DP(0, ?)
dfs_row = 0;
dfs_valid_row_st3.clear();
Dfs( 0, 0, 0);
memset( DP[ 0], 0, sizeof( DP[ 0]));
for( auto row_st3 : dfs_valid_row_st3){
auto new_col_st7 = Trans_7_3[ 0][ row_st3];
if( new_col_st7 == -1) continue;
DP[ 0][ new_col_st7] += 1;
}
}
{ // DP(>0, ?)
for( int row = 1; row < M; ++row){
dfs_row = row;
dfs_valid_row_st3.clear();
Dfs( 0, 0, 0);
memset( DP[ row], 0, sizeof( DP[ row]));
for( int col_st7 = 0; col_st7 < Pow7[ N]; ++col_st7){
for( auto row_st3 : dfs_valid_row_st3){
auto new_col_st7 = Trans_7_3[ col_st7][ row_st3];
if( new_col_st7 == -1) continue;
DP[ row][ new_col_st7] += DP[ row - 1][ col_st7];
}
}
}
}
}
return accumulate( DP[ M - 1], DP[ M - 1] + Pow7[ N], 0LL);
}
};