终于写完了这道题。
我所用的方法是 DLX,即 Dancing Links X algorithm。这是一个如何的算法呢?其所用即:O(1) 恢复链表,完成搜索剪枝。
接下来看“精确覆盖”问题:
对于一个 01 矩阵,选择若干行,使得矩阵的每一列都有且仅有一个 1。
怎么做?很显然这是 NP 问题,方法只有搜索。而在搜索算法中,有一种专为此而设的算法,即 Dancing Links X。
将 01 矩阵用一个双向循环十字链表阵(即 Dancing links)表示,每一列都有一个列头,第一列的前端链有一个表头,所有元素都有指针指向其所在列的列头。
显然,如果选取一行包含于解中,则该行所在列上的其他元素都不能选。正因为此,既然不能选,那么保留其所在行也没有意义(因为选不全),所以可以一并删去。
所以选取一行,则删去许多。这样便达到了剪枝的效果。
关于链表元素的删除,我这里还有一篇关于搜索链表优化的文章,大概是 Dancing Links 的基础。
接下来说 sudoku 一题。如何将其转化为精确覆盖呢?
显然,每个位置,每个元素对每行,每个元素对每列,每个元素对每个九宫格只能选一次。而选完 81 个数之后,每个数都对应了这些条件。
那么,就可以这样来构造精确覆盖模型:
以每个位置,每个元素对每行,每个元素对每列,每个元素对每个九宫格为列(最多共 324 列);
以所选的每个数为行;
双向循环十字链表中的每一个元素对应上所述。
显然,这个问题就解决了。而且根本不需要加什么搜索剪枝,因为 DLX 可以说是自动剪枝加减少分枝数。
不过比较令人头疼的是, DLX 构造模型时所需要的预处理是相当复杂的,但是这实在没有办法。。。。。
Code :
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <climits>
#include <iostream>
#include <algorithm>
typedef long long int64;
typedef unsigned int uint;
typedef unsigned long long uint64;
#define swap(a, b, t) ({t _ = (a); (a) = (b); (b) = _;})
#define MAX(a, b, t) ({t _ = (a), __ = (b); _ > __ ? _ : __;})
#define MIN(a, b, t) ({t _ = (a), __ = (b); _ < __ ? _ : __;})
#define maintype int
#define max(a, b) (MAX(a, b, maintype))
#define min(a, b) (MIN(a, b, maintype))
#define maxs 12
#define maxn 8005
#define getpos(i, j) (((i) - 1) / 3 * 3 + ((j) - 1) / 3 + 1)
#define abs(a) ({int _____ = a; _____ < 0 ? - _____ : _____;})
#define getcost(i, j) ({int ___ = abs(i - 5), ____ = abs(j - 5); 10 - max(___, ____);})
int tot, ans1, ans2, total, a[maxs][maxs];
struct node{node * l, * r, * u, * d, * c; int s;} vess[maxn];
typedef bool visit[maxs][maxs];
typedef node * sudoku[maxs][maxs];
visit vv, vl, vr, vp;
node * head, * tail;
sudoku mv, ml, mr, mp;
void remove(node * p)
{
p->l->r = p->r, p->r->l = p->l;
for (node * i = p->d; i != p; i = i->d)
for (node * j = i->r; j != i; j = j->r)
j->u->d = j->d, j->d->u = j->u, -- j->c->s;
}
void resume(node * p)
{
for (node * i = p->u; i != p; i = i->u)
for (node * j = i->l; j != i; j = j->l)
j->u->d = j->d->u = j, ++ j->c->s;
p->l->r = p->r->l = p;
}
void dfs(int now, int ans)
{
if (head->r == head)
{
if (now == total) return (void)(ans2 = max(ans2, ans));
return;
}
node * p = head->r;
int mini = p->s;
for (node * i = p->r; i != head; i = i->r)
if (i->s < mini)
mini = i->s, p = i;
remove(p);
for (node * i = p->d; i != p; i = i->d)
{
for (node * j = i->r; j != i; j = j->r)
remove(j->c);
dfs(now + 1, ans + i->s);
for (node * j = i->l; j != i; j = j->l)
resume(j->c);
}
resume(p);
}
node * newnode(node * tar, int s)
{
node * p = vess + ++ tot;
p->u = tar->u, tar->u->d = p;
p->d = tar, tar->u = p;
p->l = p->r = p, p->c = tar;
return ++ tar->s, p->s = s, p;
}
node * newhead()
{
tail->r = vess + ++ tot;
tail->r->l = tail, tail = tail->r;
tail->u = tail->d = tail->c = tail;
return tail;
}
void prepare()
{
head = tail = vess;
head->u = head->d = head->c = head;
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
if (not vv[i][j])
mv[i][j] = newhead();
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
if (not vl[i][j])
ml[i][j] = newhead();
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
if (not vr[i][j])
mr[i][j] = newhead();
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
if (not vp[i][j])
mp[i][j] = newhead();
tail->r = head, head->l = tail;
for (int k = 1; k <= 9; ++ k)
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
{
if (vv[i][j] or vl[k][i] or vr[k][j]) continue;
int P = getpos(i, j), s = k * getcost(i, j);
if (vp[k][P]) continue;
node * o = newnode(mv[i][j], s);
node * p = newnode(ml[k][i], s);
node * q = newnode(mr[k][j], s);
node * r = newnode(mp[k][P], s);
o->r = p, p->r = q, q->r = r, r->r = o;
r->l = q, q->l = p, p->l = o, o->l = r;
}
}
int main()
{
freopen("sudoku.in", "r", stdin);
freopen("sudoku.out", "w", stdout);
for (int i = 1; i <= 9; ++ i)
for (int j = 1; j <= 9; ++ j)
if (scanf("%d", & a[i][j]), a[i][j])
{
int k = a[i][j], p = getpos(i, j), c = getcost(i, j);
ans1 += k * c, total += vv[i][j] = vl[k][i] = vr[k][j] = vp[k][p] = 1;
}
prepare(), total = 81 - total, dfs(0, 0);
ans2 ? printf("%d\n", ans1 + ans2) : puts("-1");
return 0;
}