题目链接:
http://acm.hdu.edu.cn/showproblem.php?pid=3311
题目大意:
现在有n个寺庙和m个城市,
在这个n + m个城市之间有p条道路。
在道路上修建运输的管道有一定的花费,在某个城市挖井也有一定的花费。
求使得n个寺庙都能喝到水的最小花费
算法:
斯坦纳树模板题。
斯坦纳树指的是,包含图上某个特定点集的最小生成树。
求斯坦纳树的方法是比较常规的状态压缩DP。
dp[mask][i]代表以i为根,点集覆盖情况为mask的生成树的最小费用。
转移分为两种。
1) dp[mask][i] = min(dp[mask][i] , dp[mask][j] + dis[i][j]),这个更新显然要跑spfa
2) dp[mask][i] = min(dp[mask][i] , dp[mask1][i] + dp[mask2][i]),mask1是mask的真子集,当i不是目标点集的点时,mask2 = mask ^ mask1,当i是目标点集的点时,mask2 = mask ^ mask1 ^ (1 << sig[i]),sig[i]是点i在状压中代表的是第几位。
如果题目结果是某种形式的森林,最后还要状压背包一下
使用for (int mask1 = (mask - 1) & mask; mask1; mask1 = (mask1 - 1) & mask)枚举子集可以避免不必要的重复
本题有在本城市挖井和连向有井的城市两种选择,
可以设立一个虚拟节点0,虚拟节点到其它节点的距离就是挖井的费用,连向虚拟节点就等同于在这个城市挖井。
因为n >= 1,所以最后只要所有城市都与虚拟节点联通即可。
最后结果是dp[(1 << n + 1) - 1][0]
代码:
#include <cstdio>
#include <iostream>
#include <algorithm>
#include <sstream>
#include <cstdlib>
#include <cstring>
#include <string>
#include <climits>
#include <cmath>
#include <queue>
#include <vector>
#include <stack>
#include <set>
#include <map>
#define INF 0x3f3f3f3f
#define eps 1e-8
using namespace std;
const int MAXN = 1100;
const int MAXM = 15000;
queue <int> q;
int dp[1 << 6][MAXN];
bool inq[MAXN];
int head[MAXN], to[MAXM], w[MAXM], nxt[MAXM];
int E;
void _addedge(int u, int v, int cst)
{
to[E] = v;
nxt[E] = head[u];
w[E] = cst;
head[u] = E ++;
}
void addedge(int u, int v, int c)
{
_addedge(u, v, c);
_addedge(v, u, c);
}
int main()
{
int n, m, p;
while (scanf("%d %d %d", &n, &m, &p) == 3)
{
memset(dp, -1, sizeof(dp));
memset(head, -1, sizeof(head));
E = 0;
for (int i = 1; i <= n + m; i ++)
{
int c;
scanf("%d", &c);
addedge(0, i, c);
}
while (p --)
{
int u, v, c;
scanf("%d %d %d", &u, &v, &c);
addedge(u, v, c);
}
for (int i = 0; i <= n; i ++)
{
dp[1 << i][i] = 0;
}
for (int i = n + 1; i <= n + m; i ++)
{
dp[0][i] = 0;
}
for (int mask = 0; mask < 1 << n + 1; mask ++)
{
for(int i = 0; i <= n + m; i ++)
{
if(i <= n && ! (mask >> i & 1))
{
continue;
}
for (int mask1 = (mask - 1) & mask; mask1; mask1 = (mask1 - 1) & mask)
{
int mask2 = mask ^ mask1 | (i <= n ? 1 << i : 0);
if (dp[mask1][i] != -1 && dp[mask2][i] != -1)
{
dp[mask][i] = dp[mask][i] == -1 ? dp[mask2][i] + dp[mask1][i] : min(dp[mask][i], dp[mask2][i] + dp[mask1][i]);
}
}
}
while(! q.empty())
{
q.pop();
}
for(int i = 0; i <= n + m; i ++)
{
if(dp[mask][i] != -1)
{
q.push(i);
inq[i] = true;
}
else
{
inq[i] = false;
}
}
while(! q.empty())
{
int u = q. front();
q.pop();
inq[u] = false;
for (int filter = head[u]; filter != -1; filter = nxt[filter])
{
int v = to[filter];
if(dp[mask][v] == - 1 || dp[mask][v] > dp[mask][u] + w[filter])
{
dp[mask][v] = dp[mask][u] + w[filter];
if(! inq[v])
{
q.push(v);
inq[v] = true;
}
}
}
}
}
printf("%d\n", dp[(1 << n + 1) - 1][0]);
}
return 0;
}