并查集基本算法需要:
- 一个数组pre[]来记录当前节点的上一层。
- int unionsearch(int cur)函数,用来查找cur的根节点。
- void join(int root1, int root2)函数,用来合并分别以root1、root2为根的两个集合(传入的参数必须已经是根节点)。
#include <iostream>
#include <cstdio>
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <sstream>
#include <cstring>
#include <cmath>
#include <stack>
#include <queue>
#include <set>
using namespace std;
// Capacity parameter: pre[] holds maxn * maxn entries.
const int maxn = 1 << 8;
// pre[v] is the parent of node v; v is a root exactly when pre[v] == v.
int pre[maxn*maxn];

// Returns the root of the set that contains cur and compresses the path:
// every node visited on the way up is re-pointed directly at the root.
// pre[] must already be initialised and cur must be a valid index into it.
int unionsearch(int cur) {
    // Pass 1: climb to the root.
    int root = cur;
    while (root != pre[root]) root = pre[root];

    // Pass 2: walk the same path again and attach each node to the root.
    while (cur != root) {
        int next = pre[cur];  // remember the old parent before overwriting it
        pre[cur] = root;
        // Step to the old parent itself. Stepping to pre[next] instead would
        // skip every second node and leave it uncompressed.
        cur = next;
    }
    return root;
}
// Recursive variant of find with path compression: while the recursion
// unwinds, each visited node is attached directly to the root.
// Returns the root of the set that contains cur.
int unionsearch2(int cur) {
    int &parent = pre[cur];
    if (parent == cur) {
        return cur;  // cur is its own parent, i.e. it is the root
    }
    parent = unionsearch2(parent);  // path compression
    return parent;
}
// Merges two sets identified by their roots by hanging root1 under root2.
// The arguments are used as given (no root lookup happens here); nothing
// changes when both are the same node.
void join(int root1, int root2) {
    if (root1 == root2) {
        return;
    }
    pre[root1] = root2;
}
直接套模板的例题:
例一:
http://lx.lanqiao.cn/problem.page?gpid=T458
#include <iostream>
#include <cstdio>
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <sstream>
#include <cstring>
#include <cmath>
#include <stack>
#include <queue>
#include <set>
using namespace std;
// Capacity parameter: pre[] holds maxn * maxn entries.
const int maxn = 1050;
// Read by main as the first two input values; nodes are numbered 1..m*n.
int m, n;
// pre[v] is the parent of node v; v is a root exactly when pre[v] == v.
int pre[maxn*maxn];

// Returns the root of the set that contains cur and compresses the path:
// every node visited on the way up is re-pointed directly at the root.
// pre[] must already be initialised and cur must be a valid index into it.
int unionsearch(int cur) {
    // Pass 1: climb to the root.
    int root = cur;
    while (root != pre[root]) root = pre[root];

    // Pass 2: walk the same path again and attach each node to the root.
    while (cur != root) {
        int next = pre[cur];  // remember the old parent before overwriting it
        pre[cur] = root;
        // Step to the old parent itself. Stepping to pre[next] instead would
        // skip every second node and leave it uncompressed.
        cur = next;
    }
    return root;
}
// Reads m, n and k, then k pairs of node ids. Each pair is merged with
// union-find, and the number of disjoint sets left among nodes 1..m*n is
// printed.
int main() {
    // Standard input/output are redirected to local files.
    // NOTE(review): these look like local-testing leftovers - confirm before
    // submitting to a judge that feeds stdin directly.
    freopen("i.txt", "r", stdin);
    freopen("o.txt", "w", stdout);

    int k;
    std::cin >> m >> n >> k;

    // Initially every node is a set of its own.
    int tot = m * n;
    for (int node = 1; node <= m * n; ++node) {
        pre[node] = node;
    }

    while (k--) {
        int a, b;
        std::cin >> a >> b;
        const int rootA = unionsearch(a);
        const int rootB = unionsearch(b);
        if (rootA == rootB) {
            continue;  // already in the same set
        }
        pre[rootA] = rootB;
        --tot;  // two sets became one
    }

    std::cout << tot << std::endl;
    return 0;
}
例二:
http://lx.lanqiao.cn/problem.page?gpid=T453
先使用并查集找出导致成环的那条边(即它的两个端点),然后从其中一个端点开始dfs,便可以得到整个环。
#include <iostream>
#include <cstdio>
#include <map>
#include <string>
#include <vector>
#include <algorithm>
#include <sstream>
#include <cstring>
#include <cmath>
#include <stack>
#include <queue>
#include <set>
#include <iomanip>
using namespace std;
// Size of every per-vertex array below.
const int maxn = 100010;
// n: count read first by main; ed: the vertex the cycle search must return to.
int n, ed;
// chart[v]: adjacency list of v (main stores each edge in both directions).
// ans: vertices collected by dfs as lying on the cycle.
std::vector<int> chart[maxn], ans;
// pre[]: union-find parent array; vis[v] is non-zero once dfs has entered v.
int pre[maxn], vis[maxn];

// Depth-first search for a cycle through vertex `ed`.
//
// u      - vertex being entered.
// parent - vertex we arrived from, or -1 for the starting call.
//
// Returns true when a path from u leads back to `ed`; in that case every
// vertex on the cycle has been appended to `ans` (in no particular order).
//
// Because the graph is undirected, the edge we arrived by also appears in
// chart[u]. One occurrence of `parent` is skipped, otherwise the trivial
// walk ed -> v -> ed would be reported as a cycle. Only one occurrence is
// skipped, so a genuine parallel edge still counts as a 2-cycle.
// Recursion depth grows with the length of the explored path.
bool dfs(int u, int parent = -1) {
    if (vis[u]) {
        // Reaching an already visited vertex closes the cycle only if it is
        // the start vertex.
        return u == ed;
    }
    vis[u] = 1;
    bool parent_edge_skipped = false;
    for (std::size_t i = 0; i < chart[u].size(); i++) {
        int v = chart[u][i];
        if (v == parent && !parent_edge_skipped) {
            parent_edge_skipped = true;
            continue;
        }
        if (dfs(v, u)) {
            ans.push_back(v);
            return true;
        }
    }
    return false;
}
// Find with full path compression: returns the root of cur's set and
// re-points every node on the path from cur straight at that root.
int union_find(int cur) {
    // Locate the root.
    int top = cur;
    while (pre[top] != top) {
        top = pre[top];
    }
    // Walk the path again, attaching each node to the root.
    for (int node = cur; node != top; ) {
        const int parent = pre[node];
        pre[node] = top;
        node = parent;
    }
    return top;
}
// Reads n, then n undirected edges. Every edge is added to the adjacency
// lists and merged with union-find. An edge whose endpoints already share a
// root closes a cycle, which dfs then collects into ans. Finally the
// collected vertices are printed in ascending order, each followed by a
// space.
int main() {
    std::memset(pre, 0, sizeof(pre));
    std::memset(vis, 0, sizeof(vis));

    std::cin >> n;
    for (int v = 1; v <= n; ++v) {
        pre[v] = v;  // every vertex starts as its own root
    }

    for (int edge = 0; edge < n; ++edge) {
        int x, y;
        std::cin >> x >> y;
        chart[x].push_back(y);
        chart[y].push_back(x);

        const int rootX = union_find(x);
        const int rootY = union_find(y);
        if (rootX == rootY) {
            // Both endpoints were already connected: this edge closes the cycle.
            ed = x;
            dfs(x);
        } else {
            pre[rootX] = pre[rootY];
        }
    }

    std::sort(ans.begin(), ans.end());
    for (std::vector<int>::const_iterator it = ans.begin(); it != ans.end(); ++it) {
        std::cout << *it << " ";
    }
    return 0;
}