题意:给一个 N × M N \times M N×M的矩阵,可以进行任意多次操作将一列轮换,求每一行的最大值之和的最大值。多组数据。
Easy Version N ≤ 4 N \leq 4 N≤4, M ≤ 100 M \leq100 M≤100
Hard Version N ≤ 12 N \leq 12 N≤12, M ≤ 2000 M \leq2000 M≤2000
看这数据范围显然是个状压
相当于是每一行只能选一个
设 f ( i , S ) f(i,S) f(i,S)表示当前到 i i i,已经选了 S S S的最大值
记忆化搜索一波,暴力转一下
复杂度 O ( 4 n n 2 m ) O(4^nn^2m) O(4nn2m) 可以通过Easy Version
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
using namespace std;
int n,m;
int a[20][105],dp[105][1<<4];
int Move(int s){return (s>>1)|((s&1)<<(n-1));}
int dfs(int pos,int s)
{
if (dp[pos][s]) return dp[pos][s];
if (s==(1<<n)-1) return 0;
if (pos==m) return 0;
int &ans=dp[pos][s];
for (int t=0;t<(1<<n);t++)
{
if (s&t) continue;
int res=dfs(pos+1,s|t);
int mx=0;
for (int k=0,cur=t;k<n;k++)
{
int sum=0;
cur=Move(cur);
for (int i=0;i<n;i++)
if ((1<<i)&cur)
sum+=a[i][pos];
mx=max(mx,sum);
}
ans=max(ans,res+mx);
}
return ans;
}
void solve()
{
memset(dp,0,sizeof(dp));
scanf("%d%d",&n,&m);
for (int i=0;i<n;i++)
for (int j=0;j<m;j++)
scanf("%d",&a[i][j]);
printf("%d\n",dfs(0,0));
}
int main()
{
int T;
scanf("%d",&T);
while (T--) solve();
return 0;
}
复杂度里有 M M M,过不了HV
考虑如何消掉M
我们发现因为只有 N N N行,所以最多只有 N N N列选了数
能否快速确定这 N N N列?
贪心!
我们把列按最大值从大到小排序,如果前面的一列没选数而后面的选了
那我们完全可以改成前面没有选的,一定更优
因为只能选 N N N个数,所以只用考虑最大的 N N N列
复杂度 O ( 4 n n 3 ) O(4^nn^3) O(4nn3)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[20][2005],dp[2005][1<<12],ord[2005];
int Move(int s){return (s>>1)|((s&1)<<(n-1));}
int dfs(int pos,int s)
{
if (dp[pos][s]) return dp[pos][s];
if (s==(1<<n)-1) return 0;
if (pos==m) return 0;
int &ans=dp[pos][s];
for (int t=0;t<(1<<n);t++)
{
if (s&t) continue;
int res=dfs(pos+1,s|t);
int mx=0;
// printf("%d",t);
for (int k=0,cur=t;k<n;k++)
{
int sum=0;
cur=Move(cur);
// printf("->%d",cur);
for (int i=0;i<n;i++)
if ((1<<i)&cur)
sum+=a[i][ord[pos]];
mx=max(mx,sum);
}
// puts("");
ans=max(ans,res+mx);
}
return ans;
}
int mx[2005];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{
memset(dp,0,sizeof(dp));
memset(mx,0,sizeof(mx));
scanf("%d%d",&n,&m);
for (int i=0;i<n;i++)
for (int j=0;j<m;j++)
scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);
for (int i=0;i<m;i++) ord[i]=i;
sort(ord,ord+m,cmp);
m=min(n,m);
printf("%d\n",dfs(0,0));
}
int main()
{
int T;
scanf("%d",&T);
while (T--) solve();
return 0;
}
然而发现慢得飞起
我们发现:在dp的时候,会花费 O ( n ) O(n) O(n)找出最佳的旋转方案,而这个 O ( n ) O(n) O(n)是在 O ( 4 n ) O(4^n) O(4n)的基础上的
为什么不预处理呢
然后得到了 O ( 2 n n 3 + 4 n n ) O(2^nn^3+4^nn) O(2nn3+4nn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[20][2005],mem[2005][1<<12],dp[2005][1<<12],ord[2005];
int dfs(int pos,int s)
{
if (dp[pos][s]) return dp[pos][s];
if (s==(1<<n)-1) return 0;
if (pos==m) return 0;
int &ans=dp[pos][s];
for (int t=0;t<(1<<n);t++)
if (!(s&t))
ans=max(ans,dfs(pos+1,s|t)+mem[ord[pos]][t]);
return ans;
}
int mx[2005];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{
memset(dp,0,sizeof(dp));
memset(mx,0,sizeof(mx));
memset(mem,0,sizeof(mem));
scanf("%d%d",&n,&m);
for (int i=0;i<n;i++)
for (int j=0;j<m;j++)
scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);
for (int i=0;i<m;i++) ord[i]=i;
sort(ord,ord+m,cmp);
m=min(n,m);
for (int pos=0;pos<m;pos++)
for (int s=0;s<(1<<n);s++)
for (int k=0;k<n;k++)
{
int sum=0;
for (int i=0;i<n;i++)
if (s&(1<<i))
sum+=a[(i+k)%n][ord[pos]];
mem[ord[pos]][s]=max(mem[ord[pos]][s],sum);
}
printf("%d\n",dfs(0,0));
}
int main()
{
int T;
scanf("%d",&T);
while (T--) solve();
return 0;
}
然而还是慢得飞起
这玩意还有多组数据
我们发现这一句
if (!(s&t))
是求和 s s s没有交集的 t t t
能否快速得到这玩意呢?
还真不行
但我们可以换一种实现方式
直接递推实现dp
这样从上一维转移过来只用枚举子集
可以用这句
for (int t=s;t;t=((t-1)&s))
这样就只枚举了 s s s的子集
复杂度?
加上枚举 s s s
一共是
∑ i = 0 n C n i 2 i \sum_{i=0}^{n}C_n^i2^i i=0∑nCni2i
= ∑ i = 0 n C n n − i 2 i =\sum_{i=0}^{n}C_n^{n-i}2^i =i=0∑nCnn−i2i
= ( 1 + 2 ) n = 3 n =(1+2)^n=3^n =(1+2)n=3n
所以复杂度是 O ( 2 n n 3 + 3 n n ) O(2^nn^3+3^nn) O(2nn3+3nn)
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cctype>
#include <algorithm>
using namespace std;
int n,m;
int a[12][2000],mx[2000],pos[2000],mem[2000][1<<12],dp[2000][1<<12];
inline bool cmp(const int& a,const int& b){return mx[a]>mx[b];}
void solve()
{
scanf("%d%d",&n,&m);
memset(mx,0,sizeof(mx));
for (int i=0;i<n;i++)
for (int j=0;j<m;j++)
scanf("%d",&a[i][j]),mx[j]=max(mx[j],a[i][j]);
for (int i=0;i<m;i++) pos[i]=i;
sort(pos,pos+m,cmp);
m=min(n,m);
for (int p=0;p<m;p++)
for (int s=0;s<(1<<n);s++)
{
mem[p][s]=0;
for (int k=0;k<n;k++)
{
int sum=0;
for (int i=0;i<n;i++)
if ((1<<i)&s)
sum+=a[(i+k)%n][pos[p]];
mem[p][s]=max(mem[p][s],sum);
}
}
memset(dp,0,sizeof(dp));
for (int s=0;s<(1<<n);s++) dp[0][s]=mem[0][s];
for (int p=1;p<m;p++)
for (int s=0;s<(1<<n);s++)
{
for (int t=s;t;t=((t-1)&s))
dp[p][s]=max(dp[p][s],dp[p-1][s^t]+mem[p][t]);
dp[p][s]=max(dp[p][s],dp[p-1][s]);
}
printf("%d\n",dp[m-1][(1<<n)-1]);
}
int main()
{
int T;
scanf("%d",&T);
while (T--) solve();
return 0;
}
过于毒瘤