判断
二分图是无向图的一种,它只对无向图中的边做了限制。将图中的所有定点分到两个集合中,要求所有边的两个端点分别处于这两个集合。
解题思路也很简单。遍历每个点,并从每个点出发找它的相邻的点,对处于不同集合的点做不同得标记。如果发现标记产生冲突,那么就不是二分图;如果成功给每个点做上标记,则是二分图。
leetcode有一题判断二分图的题,链接
class Solution {
private:
constexpr static int nocolor = 0;
constexpr static int red = 1;
constexpr static int green = 2;
vector<int> color_;
bool dfs(int node, int c, const vector<vector<int>>& graph) {
color_[node] = c; // 对node节点染色
int c_next = c == red ? green : red; // 下一个节点的颜色
bool ret = true; // 默认情况下,如果一个节点没有邻接点,那就是成功的
for (auto next : graph[node]) {
if (color_[next] == nocolor) {
// 如果邻接节点没有被染色,那么尝试染色
ret = dfs(next, c_next, graph);
if (!ret) break;
} else if (color_[next] == c_next) {
// 如果邻接节点已经被染色了,而且颜色正确,那就正常,可以继续循环
ret = true;
} else {
// 如果邻接节点已经被染色了,而且颜色错误,那就失败,不能继续循环,直接退出
ret = false;
break;
}
}
return ret;
}
public:
bool isBipartite(vector<vector<int>>& graph) {
int n = graph.size();
color_.resize(n, nocolor);
for (int i = 0; i < n; ++i) {
// 遍历所有节点,尝试对每个节点染色
if (color_[i] == nocolor) {
// 如果该节点没有被染色,则染红色
if (!dfs(i, red, graph)) return false;
}
}
return true;
}
};
匹配
关于图的匹配,OI Wiki上有比较详细的介绍,参考链接,其中包括了最大匹配和最大权匹配。
UOJ上有两题相应地模板题可以测试代码:
最大匹配
匈牙利算法可以用来求最大匹配问题。要在一个二分图中找最多的匹配对,其最核心的思想就是:
找增广路。将增广路中的匹配边与非匹配边互换,就会多一条匹配边。
因为图是没有方向的,所以我们只要尽可能地从左图中的点往右图匹配就行。(这里用左图和右图来区分二分图中的两部分点,懒得画图了,请脑补)
知道增广路的理论后,最朴素的想法是,保证二分图的左图中的所有点都找不到增广路就是实现了最大匹配。
但更简单的结论是,只需要每个点遍历一次,就可以确保找到最大匹配,而不需要反复去确认每个点是否找不到增广路。
在遍历每个点的时候就只有两种结果。一:如果能匹配上,那就直接匹配就行。二:如果匹配不上,就尝试找增广路。找到了就是匹配成功,找不到增广路,后面其它点不管再怎么匹配,都改变不了这个结果,就是匹配不上。
用深度优先遍历(dfs)写最大匹配会比较方便,因为在确定找到增广路的时候需要回溯,dfs的递归调用就自带回溯,很容易实现增广路中的匹配边和非匹配边交换的目的。
#include <cstring>
#include <iostream>
#include <vector>
class Solution {
private:
std::vector<std::vector<int>> graph_; // 从左侧到右侧的邻接表
int m_; // 左侧节点数
int n_; // 右侧节点数
std::vector<int> m2n_; // 从左侧到右侧的匹配结果
std::vector<int> n2m_; // 从右侧到左侧的匹配结果
std::vector<int> visited_; // 表示左侧被访问过的节点
/**
* @brief 左侧的节点m尝试匹配一个右侧的节点n
*
* @param m 左侧节点索引值
* @return true - 匹配成功
* @return false - 匹配失败
*/
bool dfs(int m) {
visited_[m] = 1;
for (int n : graph_[m]) {
// 便利从m节点到右侧的所有邻接节点
if (n2m_[n] == -1) {
// 如果右侧节点还没有匹配的对象
m2n_[m] = n;
n2m_[n] = m;
return true;
} else {
// 如果右侧节点已经有匹配的对象了,就看看它匹配的对象有没有访问过,再进一步找增广路径
if (!visited_[n2m_[n]] && dfs(n2m_[n])) {
m2n_[m] = n;
n2m_[n] = m;
return true;
}
}
}
return false;
}
public:
Solution(int m, int n) : m_(m), n_(n) {
graph_.resize(m);
m2n_.resize(m, -1);
n2m_.resize(n, -1);
visited_.resize(m, 0);
}
/**
* @brief 向邻接表中添加边
*
* @param from - 左侧节点索引
* @param to - 右侧节点索引
* @return true - 添加成功
* @return false - 添加失败,索引值超出范围
*/
bool add_edge(int from, int to) {
if (from < 0 || from >= m_ || to < 0 || to >= n_) return false;
graph_[from].push_back(to);
return true;
}
/**
* @brief 搜索最大匹配
*
*/
void run() {
// 只需要一次遍历就可以确保已经找到了所有增广路径
for (int i = 0; i < m_; ++i) {
// 将每个节点都标记为未访问
memset(visited_.data(), 0, sizeof(int) * visited_.size());
dfs(i);
}
// 输出匹配结果
int match_cnt = 1;
for (int i = 0; i < m2n_.size(); ++i) {
if (m2n_[i] != -1) {
std::cout << "match " << match_cnt++ << ": " << i << " " << m2n_[i] << "\n";
}
}
}
};
int main(int argc, char *argv[]) {
int m, n, e;
std::cin >> m >> n >> e;
Solution solution(m, n);
// 初始化邻接表
for (int i = 0; i < e; ++i) {
int u, v;
std::cin >> u >> v;
solution.add_edge(u, v);
}
// 搜索最大匹配
solution.run();
return 0;
}
最大权匹配
二分图中每条边都带有权重,找到权重之和最大的匹配结果。
KM算法(Kuhn–Munkres Algorithm)可以用来求完备匹配下的最大权匹配,即要求二分图中左图和右图中节点个数相同。如果遇到二分图中左右节点个数不同得情况,可以手动补齐少的那部分。最后输出的时候把那些权重为0的边去掉就行。
如果延续匈牙利算法,在没有了解过KM算法的情况下,可以有这样的思路:
如果同样从左侧的点往右侧匹配。先把每个左侧的点权重最大的边拎出来,尝试匹配,如果都匹配成功了,那自然就找到了答案,因为不可能有更大的权重之和了。
但一般情况下,不会那么顺利一下子就成功,因此需要一个策略,找出一条对总权重影响最小的边添加到已经拎出来的边中,再尝试去匹配。我觉得这就是KM算法的核心思想。
主要还是参考OI Wiki中的相关介绍,如果dfs搜索成功,那可能是得到了一个增广路,也有可能是直接匹配成功;如果dfs搜索失败,得到的就是一颗交错树。在这颗交错树中,位于左图中的点集记为S,位于右图中的点记为T;不在交错树中,位于左图中的点记为S‘,位于右图中的点记为T‘。
要从T’中找一个点,来扩展这颗交错树,从而让交错树有可能可以变成增广路。KM算法巧妙地把边的权重分散到顶点上,点的权重之和表示的是最大权匹配的极限值。然后再要求边的权重只有与两边顶点权重之和相同才能成为“可见的边”。基于这样的设定,通过计算可以找出一个权重调整的最小值,让一条新的边加入进来,在加入新的边的同时,将权重的损失降低到最小,非常符合直觉。
回到代码中,在最大匹配的基础上,我们要做的就是记录下某一次dfs遍历形成的交错树。为了让交错树能够扩展开去,找一条新的从S到T’的边。
// 如果没有匹配上,就需要先计算最小的slack
int slack = INT32_MAX;
for (int j = 0; j < sz_; ++j) {
if (!vis_n_[j]) {
// T’中的点
for (int k = 0; k < sz_; ++k) {
if (vis_m_[k]) {
// S中的点
slack = std::min(slack, mw_[k] + nw_[j] - graph_[k][j]);
}
}
}
}
完整的dfs版本的代码如下:
#include <cstring>
#include <iostream>
#include <vector>
class Solution {
private:
int m_; /* 左侧节点数 */
int n_; /* 右侧节点数 */
int sz_; /* m_和n_中的较大值 */
std::vector<std::vector<int>> graph_;
std::vector<int> mw_;
std::vector<int> nw_;
std::vector<int> vis_m_;
std::vector<int> vis_n_;
std::vector<int> m2n_;
std::vector<int> n2m_;
/**
* @brief 左侧的节点m尝试匹配一个右侧的节点n
*
* @param m 左侧节点索引值
* @return true - 匹配成功
* @return false - 匹配失败
*/
bool dfs(int m) {
vis_m_[m] = 1;
for (int i = 0; i < sz_; ++i) {
if (mw_[m] + nw_[i] == graph_[m][i]) {
// 存在边
if (n2m_[i] == -1) {
// 如果右侧节点还没有匹配的对象
m2n_[m] = i;
n2m_[i] = m;
// vis_n_[i] = 1; 成功匹配上的时候就不关心有没有visit了
return true;
} else {
// 如果右侧节点已经有匹配的对象了
if (!vis_m_[n2m_[i]]) {
// 右侧节点的左侧匹配对象没有被访问过,也就是还没有加入到交错树中
vis_n_[i] = 1;
vis_m_[n2m_[i]] = 1;
if (dfs(n2m_[i])) {
// 如果成功找到增广路,那么就把非匹配边变成匹配边
m2n_[m] = i;
n2m_[i] = m;
return true;
}
}
}
}
}
return false;
}
void output() {
long long sum_weight = 0;
for (int i = 0; i < m_; ++i) {
if (graph_[i][m2n_[i]] > 0) {
sum_weight += graph_[i][m2n_[i]];
} else {
m2n_[i] = -1;
}
}
std::cout << sum_weight << "\n";
for (int i = 0; i < m_; ++i) {
std::cout << m2n_[i] + 1 << " ";
}
std::cout << "\n";
}
public:
Solution(int m, int n) : m_(m), n_(n) {
sz_ = std::max(m, n);
graph_.resize(sz_, std::vector<int>(sz_, 0));
mw_.resize(sz_, 0);
nw_.resize(sz_, 0);
vis_m_.resize(sz_, 0);
vis_n_.resize(sz_, 0);
m2n_.resize(sz_, -1);
n2m_.resize(sz_, -1);
}
/**
* @brief 向邻接表中添加边
*
* @param from - 左侧节点索引
* @param to - 右侧节点索引
* @param weight - 边的权重
* @return true - 添加成功
* @return false - 添加失败,索引值超出范围
*/
bool add_edge(int from, int to, int weight) {
if (from < 0 || from >= m_ || to < 0 || to >= n_) return false;
graph_[from][to] = weight;
return true;
}
/**
* @brief 搜索最大权匹配
*
*/
void run() {
// 初始化mw和nw
memset(nw_.data(), 0, sizeof(int) * sz_);
for (int i = 0; i < sz_; ++i) {
for (int j = 0; j < sz_; ++j) {
mw_[i] = std::max(mw_[i], graph_[i][j]);
}
}
for (int i = 0; i < m_; ++i) {
for (;;) {
// visit标记的点只有在dfs搜索失败的时候用得上,这些点就是组成交错树的点
// 需要在每次dfs搜索的最开始时候清空
memset(vis_m_.data(), 0, sizeof(int) * sz_);
memset(vis_n_.data(), 0, sizeof(int) * sz_);
if (dfs(i)) {
break;
}
// 如果没有匹配上,就需要先计算最小的slack
int slack = INT32_MAX;
for (int j = 0; j < sz_; ++j) {
if (!vis_n_[j]) {
for (int k = 0; k < sz_; ++k) {
if (vis_m_[k]) {
slack = std::min(slack, mw_[k] + nw_[j] - graph_[k][j]);
}
}
}
}
// 再调整mw和nw
for (int j = 0; j < sz_; ++j) {
if (vis_m_[j]) {
mw_[j] -= slack;
}
if (vis_n_[j]) {
nw_[j] += slack;
}
}
}
}
output();
}
};
int main(int argc, char *argv[]) {
int m, n, e;
std::cin >> m >> n >> e;
Solution solution(m, n);
// 初始化邻接表
for (int i = 0; i < e; ++i) {
int u, v, w;
std::cin >> u >> v >> w;
solution.add_edge(u - 1, v - 1, w);
}
solution.run();
return 0;
}
优化
这样的计算方式在遇到点或者边的数量非常多的情况下效率会很低。因为可以考虑一下这种情况,每次调整,我们都往交错树中新引入一条边,然后就需要重新用dfs重头搜索,比如某个左图中的点有400条边,而这400条边中,只有权重最小的那条边才能增广,这样我们就需要调整400次,然后搜索400次才能搜到一条增广路。再如果不是400,是更大的数,这个耗时就非常多。UOJ上的测试题如果用上面的代码去跑,就会出现TLE。
dfs的问题在于它需要重头搜索,这显然不太合理,因为一颗交错树我已经完整搜索得到了才确认了它无法增广,这时候新增一条边,我们完全可以在这颗交错树的基础上再去搜索,而不需要重头再搜索一遍这颗交错树。
dfs在处理最大匹配的时候非常方便,因为递归调用的过程自然地就把回溯增广路的过程实现了。而到了最大权匹配的问题中,如果想要记录这颗交错树,就需要一些额外的信息,比如交错树中的非匹配边。而广度优先搜索(bfs)本身就是需要记录回溯的信息,所以不如直接用bfs来实现,以下参考了OI Wiki中的参考代码。
交错树中的匹配边已经由m2n_和n2m_记录了,而非匹配边是用pre_数组记录的。而且pre_数组还有一个重要的作用就是记录一个新的右图顶点的潜在匹配目标。
如下代码,在bfs遍历的过程中,右图中的顶点i,如果有一个更小的slack,就更新pre_[i],即记录i顶点最优先和左图中的哪个顶点匹配。这里直接就算好了如果有一个T’中的点之后要加入进来,那它会和哪个左图中的点匹配,省去了后续重新搜索交错树的时间。
if (!vis_n_[i]) {
// 没有访问过
int delta = mw_[node_m] + nw_[i] - graph_[node_m][i];
if (slack_[i] >= delta) {
// i有一个潜在的非匹配边
pre_[i] = node_m;
if (delta) {
slack_[i] = delta;
} else if (check(i)) {
return;
}
}
}
完整的代码如下:
#include <cstring>
#include <iostream>
#include <queue>
#include <vector>
class Solution {
private:
int m_; /* 左侧节点数 */
int n_; /* 右侧节点数 */
int sz_; /* m_和n_中的较大值 */
std::vector<std::vector<int>> graph_;
std::vector<int> mw_;
std::vector<int> nw_;
std::vector<int> vis_m_;
std::vector<int> vis_n_;
std::vector<int> m2n_;
std::vector<int> n2m_;
std::vector<int> pre_; // 用来记录非匹配边从n到m的匹配关系
std::queue<int> q_;
std::vector<int> slack_;
bool check(int n) {
vis_n_[n] = 1;
if (n2m_[n] == -1) {
// i节点是未匹配的节点
n2m_[n] = pre_[n];
// 回溯,增广路中的匹配边和非匹配边调换
int empty_node_n = m2n_[pre_[n]];
while (empty_node_n != -1) {
int tmp = m2n_[pre_[empty_node_n]];
m2n_[pre_[empty_node_n]] = empty_node_n;
n2m_[empty_node_n] = pre_[empty_node_n];
empty_node_n = tmp;
}
m2n_[pre_[n]] = n;
return true;
} else {
// 这个节点已经匹配了
q_.push(n2m_[n]);
return false;
}
}
void bfs(int m) {
while (!q_.empty()) q_.pop();
q_.push(m);
for (;;) {
while (!q_.empty()) {
int node_m = q_.front();
q_.pop();
vis_m_[node_m] = 1;
for (int i = 0; i < sz_; ++i) {
if (!vis_n_[i]) {
// 没有访问过
int delta = mw_[node_m] + nw_[i] - graph_[node_m][i];
if (slack_[i] >= delta) {
// i有一个潜在的非匹配边
pre_[i] = node_m;
if (delta) {
slack_[i] = delta;
} else if (check(i)) {
return;
}
}
}
}
}
// 如果没有匹配上,就需要先计算最小的slack
int slack = INT32_MAX;
for (int j = 0; j < sz_; ++j) {
if (!vis_n_[j]) {
slack = std::min(slack, slack_[j]);
}
}
// 再调整mw和nw
for (int j = 0; j < sz_; ++j) {
if (vis_m_[j]) {
mw_[j] -= slack;
}
if (vis_n_[j]) {
nw_[j] += slack;
} else {
slack_[j] -= slack;
}
}
// 通过新增加的点去找增广路
for (int j = 0; j < sz_; ++j) {
if (!vis_n_[j] && slack_[j] == 0 && check(j)) {
// 找到增广路
return;
}
// 如果没有找到增广路,即新加入的节点是已经有匹配了的,那就需要重新找新的节点
}
}
return;
}
void output() {
long long sum_weight = 0;
for (int i = 0; i < m_; ++i) {
if (graph_[i][m2n_[i]] > 0) {
sum_weight += graph_[i][m2n_[i]];
} else {
m2n_[i] = -1;
}
}
std::cout << sum_weight << "\n";
for (int i = 0; i < m_; ++i) {
std::cout << m2n_[i] + 1 << " ";
}
std::cout << "\n";
}
public:
Solution(int m, int n) : m_(m), n_(n) {
sz_ = std::max(m, n);
graph_.resize(sz_, std::vector<int>(sz_, 0));
mw_.resize(sz_, 0);
nw_.resize(sz_, 0);
vis_m_.resize(sz_, 0);
vis_n_.resize(sz_, 0);
m2n_.resize(sz_, -1);
n2m_.resize(sz_, -1);
pre_.resize(sz_, -1);
slack_.resize(sz_, INT32_MAX);
}
/**
* @brief 向邻接表中添加边
*
* @param from - 左侧节点索引
* @param to - 右侧节点索引
* @param weight - 边的权重
* @return true - 添加成功
* @return false - 添加失败,索引值超出范围
*/
bool add_edge(int from, int to, int weight) {
if (from < 0 || from >= m_ || to < 0 || to >= n_) return false;
graph_[from][to] = weight;
return true;
}
/**
* @brief 搜索最大权匹配
*
*/
void run() {
// 初始化mw和nw
memset(nw_.data(), 0, sizeof(int) * sz_);
for (int i = 0; i < sz_; ++i) {
for (int j = 0; j < sz_; ++j) {
mw_[i] = std::max(mw_[i], graph_[i][j]);
}
}
for (int i = 0; i < m_; ++i) {
memset(vis_m_.data(), 0, sizeof(int) * sz_);
memset(vis_n_.data(), 0, sizeof(int) * sz_);
memset(slack_.data(), 0x3f, sizeof(int) * sz_);
bfs(i);
}
output();
}
};
int main(int argc, char *argv[]) {
int m, n, e;
std::cin >> m >> n >> e;
Solution solution(m, n);
// 初始化邻接表
for (int i = 0; i < e; ++i) {
int u, v, w;
std::cin >> u >> v >> w;
solution.add_edge(u - 1, v - 1, w);
}
solution.run();
return 0;
}