解题思路
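先把所有大小恰好为 mina×minb 的最小子矩阵按元素和放进小根堆；每次取出堆顶（当前和最小的子矩阵），再把它向上、下、左、右各扩展一行或一列得到的子矩阵放入堆中，第 k 次取出的堆顶就是第 k 小的子矩阵和。该做法要求矩阵元素非负（扩展子矩阵不会使和变小）；另外要记录所有入过堆的子矩阵，防止已经弹出的子矩阵被再次加入而重复计数。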
代码
暴力做法：枚举所有子矩阵，时间复杂度约 O(n²m² log(nm))，只适用于小数据
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<iomanip>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#define ll long long
#define ldb long double
using namespace std;
// Globals for the brute-force solution (zero-initialised because they are static).
int n, m;            // matrix rows and columns
int mina, minb;      // minimum rectangle height and width
int d;               // not used by the visible code
int k;               // 1-indexed rank of the sum to report
int w;               // scratch: sum of the rectangle currently being examined
int a[1010][1010];   // input matrix, 1-indexed
int b[1010][1010];   // b[i][j] = sum of a[1..i][1..j]
// Brute-force reference solution.
// Enumerates every sub-rectangle with at least mina rows and minb columns and
// pushes its sum into a min-heap. Prints the k-th smallest sum (1-indexed), or
// -1 if fewer than k such rectangles exist.
// Runs in O(n^2 * m^2 * log) time, so it is only usable on small inputs.
int main(){
    // Min-heap of rectangle sums; long long so that large sums do not wrap.
    priority_queue<long long, vector<long long>, greater<long long> > p;
    scanf("%d%d%d%d%d",&n,&m,&mina,&minb,&k);
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++)
        {
            scanf("%d",&a[i][j]);
            // b[i][j] = sum of a[1..i][1..j]
            b[i][j]=b[i-1][j]+b[i][j-1]-b[i-1][j-1]+a[i][j];
        }
    }
    // (i, j) is the top-left corner and (x, y) the bottom-right corner.
    for(int i=1;i<=n;i++)
    {
        for(int j=1;j<=m;j++)
        {
            for(int x=i;x<=n;x++)
            {
                for(int y=j;y<=m;y++)
                {
                    int xx=x-i+1,yy=y-j+1;   // rectangle height and width
                    if(xx>=mina&&yy>=minb)
                    {
                        // Inclusion-exclusion on the prefix sums, done in 64 bits.
                        p.push(1ll*b[x][y]-b[i-1][y]-b[x][j-1]+b[i-1][j-1]);
                    }
                }
            }
        }
    }
    // Discard the k-1 smallest sums.
    while(k>1&&!p.empty())
    {
        p.pop();
        k--;
    }
    // An empty heap here means fewer than k rectangles exist.
    // Testing only `k>1` missed one case: the heap runs dry on exactly the
    // (k-1)-th pop. p.top() would then be called on an empty heap, which is
    // undefined behaviour.
    if(p.empty())
    {
        printf("-1");
        return 0;
    }
    printf("%lld",p.top());
    return 0;
}
正解：堆 + 逐步扩展子矩阵
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<iomanip>
#include<cstring>
#include<cmath>
#include<map>
#include<queue>
#include <set>
#define ll long long
#define ldb long double
using namespace std;
// Problem input: matrix size n x m, minimum rectangle size mina x minb, rank k.
// d and w are not used by the visible code.
int n,m,mina,minb,d,k,w;
// Number of admissible rectangles. It needs 64 bits: for n = m = 1000 it can
// reach (n(n+1)/2)*(m(m+1)/2), about 2.5e11, far beyond INT_MAX.
long long tot;
// a: input matrix (1-indexed).
// b: 2D prefix sums of a. Kept in 64 bits to match data::sum (long long) and
// to avoid overflow when many entries are accumulated.
int a[1010][1010];
long long b[1010][1010];
// Reads the next integer from stdin.
// Leading non-digit characters are skipped. If any of them is '-', the number
// is negated. The character right after the last digit is consumed.
// Returns 0 if end of input is reached before any digit. The old version
// looped forever at EOF, because EOF never becomes a digit.
int read() {
    int x=0,v=1;
    int ch=getchar();   // int, not char, so EOF stays distinguishable
    while(ch!=EOF&&(ch<'0'||ch>'9')){
        if(ch=='-')v=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9'){
        x=x*10+(ch-'0');
        ch=getchar();
    }
    return x*v;
}
// A candidate sub-rectangle, covering rows x1..x2 and columns y1..y2
// (1-indexed, inclusive), together with its element sum.
// Ordered by sum first. Ties are broken by the coordinates, so distinct
// rectangles with equal sums remain distinct keys in a std::set.
// NOTE(review): with `using namespace std` under C++17, unqualified uses of the
// name `data` outside this struct can be ambiguous with std::data; qualify them
// as ::data.
struct data{
    int x1,y1,x2,y2;
    long long sum;
    // Strict weak ordering on (sum, x1, y1, x2, y2).
    // Takes the other operand by const reference instead of copying both operands.
    bool operator<(const data &o)const{
        if(sum!=o.sum)return sum<o.sum;
        if(x1!=o.x1)return x1<o.x1;
        if(y1!=o.y1)return y1<o.y1;
        if(x2!=o.x2)return x2<o.x2;
        return y2<o.y2;
    }
}top,tmp;
std:: set<data> heap;
void get_sum(data &a){
a.sum=b[a.x2][a.y2]-b[a.x1-1][a.y2]-b[a.x2][a.y1-1]+b[a.x1-1][a.y1-1];
}
int main(){
scanf("%d%d%d%d%d",&n,&m,&mina,&minb,&k);
for(int i=1;i<=n;i++)
{
for(int j=1;j<=m;j++)
{
a[i][j]=read();
b[i][j]=b[i-1][j]+b[i][j-1]-b[i-1][j-1]+a[i][j];
}
}
for(int i=1;i<=n-mina+1;i++)
{
for(int j=1;j<=m-minb+1;j++)
{
data tmp=(data){i,j,i+mina-1,j+minb-1};
get_sum(tmp);heap.insert(tmp);
}
}
for(int i=mina;i<=n;i++)
{
for(int j=minb;j<=m;j++)
tot+=1ll*(n-i+1)*(m-j+1);
}
if(tot<=k)
{
printf("-1");
return 0;
}
for(int i=k;i>1;i--)
{
top=*heap.begin();heap.erase(top);
if(top.x1>1){
tmp=(data){top.x1-1,top.y1,top.x2,top.y2};
get_sum(tmp);heap.insert(tmp);
}
if(top.y1>1){
tmp=(data){top.x1,top.y1-1,top.x2,top.y2};
get_sum(tmp);heap.insert(tmp);
}
if(top.x2<n){
tmp=(data){top.x1,top.y1,top.x2+1,top.y2};
get_sum(tmp);heap.insert(tmp);
}
if(top.y2<m){
tmp=(data){top.x1,top.y1,top.x2,top.y2+1};
get_sum(tmp);heap.insert(tmp);
}
}
data top=*heap.begin(); get_sum(top);
printf("%u\n", top.sum);
}