When doing bitmasking DP, we often need to answer questions like "what is the i-th bit of a state?" or "how many bits are set (valid) in a state?". Here are some useful tricks:
- We can use (x >> i) & 1 to get the i-th bit of state x (bit 0 is the least significant bit).
- We can use (x & y) == x to check whether x is a subset of y, i.e. every bit that is 1 in x is also 1 in y.
- We can use (x & (x >> 1)) == 0 to check that no two adjacent bits of x are both 1.
- We can use Integer.bitCount(x) in Java to count the bits that are 1 in x.
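The tricks above can be wrapped in small helpers, as in this minimal Java sketch. The class name BitTricks and its method names are hypothetical, chosen only for illustration.

```java
// Hypothetical helper class illustrating the bit tricks listed above.
final class BitTricks {
    // Returns the i-th bit of x (bit 0 is the least significant bit).
    static int getBit(int x, int i) {
        return (x >> i) & 1;
    }

    // True if every bit set in x is also set in y.
    static boolean isSubset(int x, int y) {
        return (x & y) == x;
    }

    // True if no two adjacent bits are both set in x.
    static boolean noAdjacent(int x) {
        return (x & (x >> 1)) == 0;
    }

    // Number of set bits in x.
    static int countBits(int x) {
        return Integer.bitCount(x);
    }
}
```

For example, isSubset(0b0101, 0b1101) is true, while noAdjacent(0b0110) is false because bits 1 and 2 are both set.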
Solution for LeetCode 1349 (Maximum Students Taking Exam)
https://leetcode.com/problems/maximum-students-taking-exam/
Steps:
- Encode each row as a bitmask of its seats: a bit is 1 if the seat is usable ('.') and 0 if it is broken ('#').
- Define dp[i][mask] as the maximum number of students that can sit in rows 0 through i when the students in row i are placed according to mask. A mask is only allowed if no two adjacent bits are both 1, since students sitting side by side can see each other's answers.
- Create dp with m rows (one per row of seats) and 2^n columns (one per mask, where n is the number of seats in a row), and fill every entry with -1 to mark it as invalid.
- For each row, go over all 2^n masks and keep only those that are a subset of the row's validity mask and have no two adjacent bits set.
- For the first row (i == 0), dp[0][j] is simply the number of 1s in j, because there is no previous row.
- Otherwise, go over all 2^n masks k of the previous row and take the maximum of dp[i - 1][k] + bitCount(j), considering k only if 1) it does not conflict with j, i.e. no student in row i - 1 sits at the upper left or upper right of a student in row i, and 2) k is valid for the previous row (dp[i - 1][k] != -1).
- Keep track of the largest dp value seen along the way; it is the answer.
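The encoding and mask-filtering steps could look roughly like the sketch below. RowMasks, encodeRow and validMasks are hypothetical names; the encoding places column 0 in the most significant of the n bits, the same way the solution code builds its validity array.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helpers: encode one row and list the masks that are valid for it.
final class RowMasks {
    // '.' (usable seat) -> 1, '#' (broken seat) -> 0; column 0 becomes the highest bit.
    static int encodeRow(char[] row) {
        int mask = 0;
        for (char c : row) {
            mask = (mask << 1) | (c == '.' ? 1 : 0);
        }
        return mask;
    }

    // All n-bit masks that use only usable seats and seat no two students side by side.
    static List<Integer> validMasks(int validity, int n) {
        List<Integer> result = new ArrayList<>();
        for (int mask = 0; mask < (1 << n); mask++) {
            boolean subset = (mask & validity) == mask;
            boolean noAdjacent = (mask & (mask >> 1)) == 0;
            if (subset && noAdjacent) {
                result.add(mask);
            }
        }
        return result;
    }
}
```

For instance, encodeRow(".#..#.".toCharArray()) gives 0b101101, and validMasks(0b101, 3) gives [0, 1, 4, 5].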
In short: dp[i][mask] = max(dp[i - 1][mask']) + bitCount(mask), where mask' ranges over the masks that are valid for row i - 1 and compatible with mask.
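A single transition of this recurrence, sketched in Java under the assumption that the previous row's dp values arrive as an array of length 2^n with -1 for invalid masks. Transition, compatible and nextRow are hypothetical names.

```java
// Hypothetical sketch of one DP transition: dp[i][*] computed from dp[i - 1][*].
final class Transition {
    // True if no student in row i - 1 (pattern prev) sits at the upper left or
    // upper right of a student in row i (pattern cur). Directly in front is allowed.
    static boolean compatible(int cur, int prev) {
        return (cur & (prev >> 1)) == 0 && ((cur >> 1) & prev) == 0;
    }

    // Returns the dp row for row i; -1 marks masks that are not valid for this row.
    static int[] nextRow(int[] prevDp, int validity, int n) {
        int size = 1 << n;
        int[] cur = new int[size];
        java.util.Arrays.fill(cur, -1);
        for (int mask = 0; mask < size; mask++) {
            // Skip masks that use broken seats or seat two students side by side.
            if ((mask & validity) != mask || (mask & (mask >> 1)) != 0) continue;
            for (int prev = 0; prev < size; prev++) {
                if (prevDp[prev] != -1 && compatible(mask, prev)) {
                    cur[mask] = Math.max(cur[mask], prevDp[prev] + Integer.bitCount(mask));
                }
            }
        }
        return cur;
    }
}
```

With n = 2, a fully usable row (validity 0b11) and prevDp = {0, 1, 1, -1}, nextRow returns {1, 2, 2, -1}.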
import java.util.Arrays;

class Solution {
    public int maxStudents(char[][] seats) {
        int m = seats.length;
        int n = seats[0].length;
        int[] validity = new int[m];
        // Encode each row as a bitmask: '.' -> 1 (usable), '#' -> 0 (broken).
        // For example, the row ".#..#." becomes 101101.
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < n; j++) {
                validity[i] = (validity[i] << 1) + (seats[i][j] == '.' ? 1 : 0);
            }
        }
        int stateSize = 1 << n;
        int[][] dp = new int[m][stateSize];
        // -1 marks a mask that is not a valid arrangement for that row.
        for (int i = 0; i < m; i++) Arrays.fill(dp[i], -1);
        int ans = 0;
        for (int i = 0; i < m; i++) {
            for (int j = 0; j < stateSize; j++) {
                // j may only use usable seats and must not seat two students side by side.
                if (((j & validity[i]) == j) && ((j & (j >> 1)) == 0)) {
                    if (i == 0) {
                        // First row: there is no previous row to conflict with.
                        dp[i][j] = Integer.bitCount(j);
                    } else {
                        for (int k = 0; k < stateSize; k++) {
                            // k must be valid for row i - 1 and must not place a student
                            // at the upper left or upper right of a student in j.
                            if ((j & (k >> 1)) == 0 && ((j >> 1) & k) == 0 && dp[i - 1][k] != -1) {
                                dp[i][j] = Math.max(dp[i][j], dp[i - 1][k] + Integer.bitCount(j));
                            }
                        }
                    }
                }
                ans = Math.max(ans, dp[i][j]);
            }
        }
        return ans;
    }
}
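A possible variant, not part of the original solution: iterate only over the masks that are valid for each row and keep just the previous row's dp values, so memory drops from O(m * 2^n) to O(2^n). The worst-case time is unchanged, but rows with broken seats have far fewer valid masks to pair up. SolutionCompact is a hypothetical name. The answer is read from the last row only, which works because dp[i][0] already carries the best total of all rows above.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical variant: rolling dp array plus a per-row list of valid masks.
final class SolutionCompact {
    public int maxStudents(char[][] seats) {
        int m = seats.length, n = seats[0].length;
        int size = 1 << n;
        int[] prev = new int[size];
        Arrays.fill(prev, -1);
        prev[0] = 0; // virtual empty row above row 0
        List<Integer> prevMasks = List.of(0);
        for (int i = 0; i < m; i++) {
            int validity = 0;
            for (int j = 0; j < n; j++) {
                validity = (validity << 1) | (seats[i][j] == '.' ? 1 : 0);
            }
            int[] cur = new int[size];
            Arrays.fill(cur, -1);
            List<Integer> curMasks = new ArrayList<>();
            for (int mask = 0; mask < size; mask++) {
                // Same validity and no-side-by-side checks as the original solution.
                if ((mask & validity) != mask || (mask & (mask >> 1)) != 0) continue;
                curMasks.add(mask);
                for (int p : prevMasks) {
                    // Only previous masks that were valid are stored in prevMasks.
                    if ((mask & (p >> 1)) == 0 && ((mask >> 1) & p) == 0) {
                        cur[mask] = Math.max(cur[mask], prev[p] + Integer.bitCount(mask));
                    }
                }
            }
            prev = cur;
            prevMasks = curMasks;
        }
        int ans = 0;
        for (int p : prevMasks) ans = Math.max(ans, prev[p]);
        return ans;
    }
}
```

On the first sample grid from the problem page (rows "#.##.#", ".####.", "#.##.#") this returns 4, the same as the original solution.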