知识点 - SOS-DP
解决问题类型:
全称 $Sum\ over\ Subsets\ dynamic\ programming$,即子集和 dp。
$F[mask] = \sum_{i \subseteq mask} A[i]$,其中 $i \subseteq mask$ 表示 $i$ 是 $mask$ 的子集,即 `(i & mask) == i`。
实现
$O(4^n)$ 暴力:对每个 $mask$ 枚举全部 $i$,判断 $i$ 是否为其子集
// O(4^n) brute force: for every mask, scan all i and add A[i] when i is a submask of mask.
for (int mask = 0; mask < (1 << N); ++mask) {   // increment mask (not i), otherwise the loop never ends
    F[mask] = 0;
    for (int i = 0; i < (1 << N); ++i)
        if ((i & mask) == i)   // parentheses required: == binds tighter than &
            F[mask] += A[i];
}
$O(3^n)$:枚举子集(每个 $mask$ 只遍历它自己的子集)
// O(3^n): walk only the submasks of mask via i = (i - 1) & mask.
// That loop visits every non-empty submask exactly once but never i = 0,
// so A[0] (the empty submask) is added up front.
for (int mask = 0; mask < (1 << N); ++mask) {
    F[mask] = A[0];
    for (int i = mask; i > 0; i = (i - 1) & mask)
        F[mask] += A[i];   // sum the inputs A[i], not the partially built F values
}
$O(n \cdot 2^n)$:SOS-DP,类似于高维前缀和
$dp[mask][i]$ 代表满足 $x \,\&\, mask = x$ 且 $x \oplus mask < 2^{i+1}$ 的所有 $A[x]$ 之和。意思就是:$dp[mask][i]$ 是 $mask$ 的子集中,只在最低的 $i+1$ 位(第 $0 \sim i$ 位)上与 $mask$ 不同的那些 $x$ 的 $A[x]$ 之和(约定 $dp[mask][-1] = A[mask]$);
if (mask&(1<<i)==0) dp[mask][i] = dp[mask][i-1]
else dp[mask][i] = dp[mask][i-1]+dp[mask^(1<<i)][i-1];
例题
- CF1208F
题意:求 $\max\big(a_i \mid (a_j \,\&\, a_k)\big)$,其中 $1 \le i < j < k \le n$
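分析:从后往前枚举 $i$,依次将每一个数加入 SOS-dp 中。对每个 $mask$ 记录有多少个已加入的数是它的超集,最多记到 2;达到 2 后不再向下递归(记忆化剪枝)。对当前 $a_i$,从高位到低位贪心:$a_i$ 本身有的位直接计入答案;否则,若已加入的数中至少有两个包含「已选位 ∪ 当前位」,说明 $a_j \,\&\, a_k$ 能提供该位,于是将其选入。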
参考代码:见下方 `//CF1208F` 部分
代码
//iterative version
// dp[mask][k]: sum of A[x] over submasks x of mask that differ from mask only in the lowest k bits.
// The level index is shifted by one relative to the recurrence dp[mask][i] (bits 0..i), so the
// base case is dp[mask][0]. C++ has no dp[mask][-1]; writing it is out-of-bounds UB.
// dp therefore needs N + 1 columns.
for (int mask = 0; mask < (1 << N); ++mask) {
    dp[mask][0] = A[mask];   // base case: no bit may differ (leaf state)
    for (int i = 0; i < N; ++i) {
        // mask ^ (1 << i) < mask whenever bit i is set, so its row is already complete.
        if (mask & (1 << i))
            dp[mask][i + 1] = dp[mask][i] + dp[mask ^ (1 << i)][i];
        else
            dp[mask][i + 1] = dp[mask][i];
    }
    F[mask] = dp[mask][N];
}
//memory optimized, super easy to code.
for(int i = 0; i<(1<<N); ++i)
F[i] = A[i];
for(int i = 0;i < N; ++i) for(int mask = 0; mask < (1<<N); ++mask){
if(mask & (1<<i))
F[mask] += F[mask^(1<<i)];
}
//CF1208F
#include <bits/stdc++.h>
using namespace std;
// Competitive-programming shorthands.
// rep(v, lo, hi): ascending loop of int v over the closed range [lo, hi].
#define rep(v, lo, hi) for (int v = (int)lo; v <= (int)hi; ++v)
// debug(x): print "name:value" followed by a newline to stderr.
#define debug(x) cerr << #x << ":" << x << endl
#define pb push_back
using ll = long long;
using pi = pair<int, int>;
// Input values are assumed to fit in BITS = 21 bits. This matches the 1 << 21 mask table,
// and the query scans bits 20..0.
const int BITS = 21;
const int MAXN = (int)1 << BITS;
// dp[mask][k] = min(2, number of inserted values v such that mask can be obtained from v
// by clearing some of v's bits 0..k-1). So dp[mask][BITS] = min(2, #inserted supersets of mask).
// Levels run 0..BITS (BITS + 1 columns), so the top bit (bit BITS-1) can be cleared as well.
// Counts never exceed 2, so unsigned char is enough (~46 MB instead of ~185 MB of int).
unsigned char dp[MAXN][BITS + 1];
int a[MAXN];

// Insert one value into dp, starting from state (value, 0).
// At level num the walk keeps bit num, or clears it if it is set, and then moves to num + 1.
// Each inserted value therefore reaches every state (mask, k) at most once.
void update(int mask, int num) {
    if (num > BITS) return;
    // Memoization: if two values already reached this state, every state below it
    // already has count 2 too, so further propagation cannot change anything.
    if (dp[mask][num] >= 2) return;
    dp[mask][num]++;
    update(mask, num + 1);                                    // keep bit num
    if (mask >> num & 1) update(mask ^ (1 << num), num + 1);  // clear bit num
}

// Reads n and a_1..a_n, then prints max over i < j < k of a_i | (a_j & a_k).
int main()
{
    int N;
    scanf("%d", &N);
    for (int i = 1; i <= N; i++) scanf("%d", &a[i]);
    int ans = 0;
    // Sweep i right to left: when a[i] is queried, dp contains exactly a[i+1..N].
    for (int i = N; i >= 1; i--) {
        if (i <= N - 2) {    // j, k > i require at least two inserted values
            int now = 0;     // bits that a_j & a_k is required to contain
            int tmpAns = 0;  // best a_i | (a_j & a_k) for this i
            for (int j = BITS - 1; j >= 0; j--) {
                if (a[i] >> j & 1) {
                    tmpAns |= 1 << j;  // a_i already supplies bit j
                } else if (dp[now | (1 << j)][BITS] >= 2) {
                    // Two later values are supersets of now | bit j, so a_j & a_k can
                    // supply bit j on top of the bits already chosen.
                    now |= 1 << j;
                    tmpAns |= 1 << j;
                }
            }
            ans = std::max(ans, tmpAns);
        }
        update(a[i], 0);
    }
    printf("%d\n", ans);
    return 0;
}