Description
小凸和小方是好朋友,小方给小凸一个N*M(N<=M)的矩阵A,要求小秃从其中选出N个数,其中任意两个数字不能在同一行或同一列,现小凸想知道选出来的N个数中第K大的数字的最小值是多少。
Input
第一行给出三个整数N,M,K
接下来N行,每行M个数字,用来描述这个矩阵
Output
如题
Sample Input
3 4 2
1 5 6 6
8 3 4 3
6 8 6 3
Sample Output
3
HINT
1<=K<=N<=M<=250,1<=矩阵元素<=10^9
分析
二分答案x,判断是否可以去最少n-k+1个数<=x
判断用最大流,行和列都看成点,连边(S,i,1),(i+n,T,1),对于格子(i,j)如果它的数<=x连边(i,j+n,1)
代码
#include <bits/stdc++.h>
const int N = 530;
const int INF = 0x7fffffff / 3;
int read()
{
int x = 0, f = 1;
char ch = getchar();
while (ch < '0' || ch > '9') {if (ch == '-') f = -1; ch = getchar();}
while (ch >= '0' && ch <= '9') {x = x * 10 + ch - '0'; ch = getchar();}
return x * f;
}
struct Edge
{
int to,next,c;
}e[N * N];
int cnt;
int next[N];
void add(int x,int y,int c)
{
e[++cnt].to = y, e[cnt].next = next[x], next[x] = cnt, e[cnt].c = c;
}
void ins(int x,int y,int c)
{
add(x,y,c);
add(y,x,0);
}
int s,t;
int dis[N * N];
bool bfs()
{
for (int i = s; i <= t; i++)
dis[i] = 0;
std::queue <int> Q;
Q.push(s);
dis[s] = 1;
while (!Q.empty())
{
int u = Q.front();
Q.pop();
for (int i = next[u]; i; i = e[i].next)
{
if (e[i].c && !dis[e[i].to])
{
dis[e[i].to] = dis[u] + 1;
if (e[i].to == t)
return 1;
Q.push(e[i].to);
}
}
}
return 0;
}
int cur[N];
int dfs(int x,int maxf)
{
if (x == t || !maxf)
return maxf;
int ret = 0;
for (int &i = cur[x]; i; i = e[i].next)
if (e[i].c && dis[e[i].to] == dis[x] + 1)
{
int f = dfs(e[i].to,std::min(e[i].c,maxf - ret));
ret += f;
e[i].c -= f;
e[i ^ 1].c += f;
if (ret == maxf)
break;
}
return ret;
}
int dinic()
{
int ans = 0;
while (bfs())
{
for (int i = s; i <= t; i++)
cur[i] = next[i];
ans += dfs(s, INF);
}
return ans;
}
int n,m,k;
int a[N][N];
void rebuild(int x)
{
cnt = 1;
for (int i = s; i <= t; i++)
next[i] = 0;
for (int i = 1; i <= n; i++)
ins(s, i, 1);
for (int i = 1; i <= m; i++)
ins(i + n, t, 1);
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
if (a[i][j] <= x)
ins(i, j + n, 1);
}
bool check(int x)
{
rebuild(x);
int ret = dinic();
if (ret >= n - k + 1)
return 1;
return 0;
}
int main()
{
n = read(), m = read(), k = read();
s = 0, t = n + m + 1;
for (int i = 1; i <= n; i++)
for (int j = 1; j <= m; j++)
a[i][j] = read();
int l = 0, r = 1e9 + 7;
while (l < r)
{
int mid = (l + r) >> 1;
if (check(mid))
r = mid;
else l = mid + 1;
}
printf("%d\n",l);
}