Description
小凸和小方是好朋友,小方给小凸一个N*M(N<=M)的矩阵A,要求小秃从其中选出N个数,其中任意两个数字不能在同一行或同一列,现小凸想知道选出来的N个数中第K大的数字的最小值是多少。
Input
第一行给出三个整数N,M,K
接下来N行,每行M个数字,用来描述这个矩阵
Output
如题
Sample Input
3 4 2
1 5 6 6
8 3 4 3
6 8 6 3
Sample Output
3
HINT
1<=K<=N<=M<=250,1<=矩阵元素<=10^9
Solution
显然是二分答案mid,然后判断能否取出n-k+1个小于等于mid的数。
判断方法就是网络流(蒟蒻一开始根本没想到,还以为是DP,23333)
若a[i][j]小于mid,就把行i与列j连边,然后最大流就为答案。
#include<iostream>
#include<cstring>
#include<cstdio>
#include<algorithm>
#include<cmath>
#include<queue>
using namespace std;
#define N 300
#define INF 0x7fffffff
struct edge
{
int y,next,w;
};
edge side[N*N*2+4*N];
int a[N][N],p[N*N],last[N+N],dis[N+N];
int s,t,n,m,k,l,ans,tot,an=INF;
bool comp(int x,int y)
{
return x<y;
}
void init()
{
scanf("%d%d%d",&n,&m,&k);
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
{
scanf("%d",&a[i][j]);
int pos=(i-1)*m+j;
p[pos]=a[i][j];
}
sort(p+1,p+n*m+1,comp);
l=unique(p+1,p+m*n+1)-p-1;
}
void add(int x,int y,int w)
{
tot++;
side[tot].y=y;
side[tot].w=w;
side[tot].next=last[x];
last[x]=tot;
}
void addside(int x,int y)
{
add(x,y,1);
add(y,x,0);
}
void build(int x)
{
tot=1;
memset(last,0,sizeof(last));
memset(side,0,sizeof(side));
s=0; t=n+m+1;
for (int i=1;i<=n;i++)
for (int j=1;j<=m;j++)
if (a[i][j]<=x)
addside(i,j+n);
for (int i=1;i<=n;i++)
addside(s,i);
for (int i=n+1;i<=n+m;i++)
addside(i,t);
}
queue<int> Q;
bool bfs()
{
memset(dis,0,sizeof(dis));
dis[s]=1;
Q.push(s);
while (!Q.empty())
{
int now=Q.front(); Q.pop();
for (int i=last[now];i!=0;i=side[i].next)
if (side[i].w>0)
{
int j=side[i].y;
if (dis[j]==0)
{
dis[j]=dis[now]+1;
Q.push(j);
}
}
}
if (dis[t]!=0) return true;
else return false;
}
int dfs(int x,int maxf)
{
if (x==t || maxf==0) return maxf;
int ret=0;
for (int i=last[x];i!=0;i=side[i].next)
if (side[i].w>0 && dis[x]+1==dis[side[i].y])
{
int f=dfs(side[i].y,min(maxf-ret,side[i].w));
side[i].w-=f;
side[i^1].w+=f;
ret+=f;
if (ret==maxf) break;
}
return ret;
}
void dinic()
{
ans=0;
while (bfs()) ans+=dfs(s,INF);
}
int main()
{
init();
int le=1,ri=l,mid;
while (le<=ri)
{
mid=(le+ri)/2;
build(p[mid]);
dinic();
if (ans>=n-k+1) an=min(an,p[mid]),ri=mid-1;
else le=mid+1;
}
printf("%d\n",an);
return 0;
}
ps
没有一次对的原因:放入p数组时pos计算错误(原来是pos=(i-1)*n+j,这样会覆盖一些数的),下次要注意。