题意:有一个N*M的01矩阵a,已知每一行有多少个1以及每一列有多少1。现在这个矩阵搞丢了,但是会告诉你a[i,j]是1的概率p[i,j](一个[0,100)的整数表示百分率)。让你还原出一个概率最大的符合条件的01矩阵,任意输出一个。
据说这题只能用zkw费用流过。。这个太奇怪了。
这题调了我一个下午,很有启发意义。
首先这是经典的矩阵还原模型,就是行做X部,列做Y部,之间的连边的流量代表a[i,j]的值,之前做过两道类似的题,因此这里还是想到了的。
然后就是最精华的地方了,概率应该是所有的p[i,j]的乘积,但是费用流的模型是和,我开始试着改一下费用流的写法,就是把中间找最短路的地方改成乘法,这样理论上可行但是乘不到几个数就会严重暴long long了。。然后我就上网搜了,我惊讶地发现竟然没几个人写这道题的题解,并且没人贴程序。这题有这么偏门吗。。找了半天才找到,要把费用设为概率的对数!这个真是太巧妙了,因为两个正数的大小关系完全等同于他们取对数的大小,然后两个数相乘就对应了他们的对数相加!!!然后要注意因为是最大费用,将他们费用全部取反,最后再取反回来即可。
然后我就无脑地写了一发,然后就交了,结果只过了前四个点,后面全T了。。然后我就各种调,我开始以为是连边有问题产生了负权环,然后下了组数据写了个SPFA判定,结果没负环。然后我就翻来覆去看了几遍费用流的模板,觉得没问题,,然后就抓狂了,在各种细节上试,,最后终于试出来了,就是update里面判是否是最短路径上的点那个dis的判定我直接用的等号,但事实上由于精度问题转移几次之后该相等的已经不相等了,于是我写了个eps和自定义了个等号。。好消息,没T了,但是都WA了,,,然后我又找了半天,觉得好像取对数的时候精度不好,于是把概率全部乘了很大一个数之后再取的对数,然后终于过了。以前openjudge上面有道题就是,要先乘一个数再做除法,不然精度要爆。
这题没看到别人写的代码,也完全不知道精度会有这种问题,全是自己摸索的,写出来还是很愉快。
这道题的启发:
1、比较乘积但是太大了,可以换成取对数之后的加法。
2、像什么取对数,开根号之类容易挂精度的,可以先乘一个很大的数。
3、系统的对数函数好像并不是O(1)的(或者是常数有点大)。系统自带的有log,log2,log10,分别是以e,2,10为底,其中以2位底的会比另外两个快三倍,如果只是用来比大小,最后用log2。
#include<cstdio>
#include<cstring>
#include<cassert>
#include<cmath>
#include<algorithm>
using namespace std;
#define DB double
#define clr(a) memset(a,0,sizeof a)
inline DB min(const DB&a, const DB&b)
{ return a < b ? a : b; }
#define rep(a,b,c) for (int a=b;a<=c;++a)
const int inf = 0x3f3f3f3f;
const int MAXN = 205;
const int MAXM = 100000;
int N, M;
const DB eps = 1e-10;
bool cmp(DB a, DB b) //自己定义在精度误差范围内的等号
{
if (a < b) swap(a, b);
return a-b <= eps;
}
struct Ed {
int to, cap;
DB cost;
Ed*nxt, *back;
};
struct FlowNet
{
Ed Edge[MAXM], *ecnt, *adj[MAXN];
FlowNet () { ecnt=Edge; }
DB dis[MAXN];
bool vis[MAXN];
int vn, S, T, flow;
DB tot;
inline void adde(int a, int b, int c, DB d)
{
(++ecnt)->to = b;
ecnt->cap = c;
ecnt->cost = d;
ecnt->nxt = adj[a];
ecnt->back = ecnt+1;
adj[a] = ecnt;
(++ecnt)->to = a;
ecnt->cap = 0;
ecnt->cost = -d;
ecnt->nxt = adj[b];
ecnt->back = ecnt-1;
adj[b] = ecnt;
}
void init(int n, int s, int t)
{
tot = 0; flow = 0;
vn = n; S = s; T = t;
clr(dis), clr(adj);
ecnt = Edge;
}
bool update()
{
DB tmp = 1e10;
rep(i, 1, vn) if (vis[i])
for (Ed *p = adj[i]; p; p=p->nxt)
if (p->cap > 0 && !vis[p->to])
tmp = min(tmp, dis[p->to]-dis[i] + p->cost);
if (tmp == 1e10) return 0;
for (int i = 1; i<=vn; ++i)
if (vis[i]) dis[i] += tmp;
return 1;
}
int aug(int u, int augco)
{
if (u == T)
{
tot += dis[S] * augco;
flow += augco;
return augco;
}
vis[u] = 1;
int delta, augc = augco;
for (Ed*p = adj[u]; p && augc; p=p->nxt)
{
int&v = p->to;
if (!vis[v] && p->cap && cmp(dis[u],dis[v]+p->cost)) //写等号要出问题
{
delta = min(p->cap, augc);
delta = aug(v, delta);
p->cap -= delta, p->back->cap += delta;
augc -= delta;
}
}
return augco - augc;
}
DB mcmf()
{
do {
do clr(vis); while (aug(S, inf));
} while (update());
return tot;
}
void drawMap(int img[105][105])
{
rep(i, 1, N)
for (Ed*p = adj[i]; p; p=p->nxt)
if (p->to > i && p->to<=N+M)
img[i][p->to - N] = p->back->cap;
}
} G;
int img[105][105];
int main()
{
scanf("%d%d", &N, &M);
G.init(N+M+2, N+M+1, N+M+2);
DB tmp;
rep(i, 1, N) rep(j, 1, M)
{
scanf("%lf", &tmp);
G.adde(i, j+N, 1, -log2(tmp*1e6)); //调整精度
}
int xx;
rep(i, 1, N)
{
scanf("%d", &xx);
G.adde(G.S, i, xx, 0);
}
rep(j, 1, M)
{
scanf("%d", &xx);
G.adde(N+j, G.T, xx, 0);
}
G.mcmf();
G.drawMap(img);
rep(i, 1, N)
{
rep(j, 1, M) printf("%d", img[i][j]);
puts("");
}
return 0;
}