知乎讲解链接
这个感觉讲的挺好的
带权并查集作用:它可以用 d d d 数组记录下每个节点相对于根节点的距离。
#define MAXN 50005
int fa[MAXN],d[MAXN];
int find(int x) {
if(fa[x] == x) return x;
else {
int oldFa = fa[x];
fa[x] = find(oldFa);
d[x] = d[x] + d[oldFa];
return fa[x];
}
}
void merge(int x, int y, int w) {
int fax = find(x), fay = find(y);
if(fax == fay) return;
fa[fax] = fay;
d[fax] = -d[x] + d[y] + w;
// d[x] + d[fax] = w + d[y]
// 将 x 所在分支合并到 y,节点 x 到 y 的距离是 w
}
int dist(int x, int y)
{// 把 x 分支合并到 y,那么距离就是 dis[x] - dis[y],反之就是 dis[y] - dis[x]
int fax = find(x),fay = find(y);
if(fax != fay) return -1;
else return d[x] - d[y];
}
关于合并操作,把
3
,
4
3,4
3,4 分支合并到
1
,
2
1,2
1,2 上
f[fx] = fy;
首先我们会知道一条连接两个分支的边的长度为
w
w
w
这里为连接节点
3
,
4
3,4
3,4 一条长度为
3
3
3 的边
牢记
d
[
x
]
d[x]
d[x] 表示
x
x
x 节点到根节点的距离
合并之前
d
[
f
x
]
=
0
d[fx]=0
d[fx]=0
可以根据矢量关系得到如下式子
d
[
x
]
d[x]
d[x] 为
x
x
x 到
f
x
fx
fx 的距离,
d
[
y
]
d[y]
d[y] 为
y
y
y 到
f
y
fy
fy 的距离
d
[
f
x
]
+
d
[
x
]
=
w
+
d
[
y
]
⟹
d
[
f
x
]
=
w
+
d
[
y
]
−
d
[
x
]
d[fx]+d[x]=w+d[y]\\ \Longrightarrow d[fx]=w+d[y]-d[x]
d[fx]+d[x]=w+d[y]⟹d[fx]=w+d[y]−d[x]
P2024 [NOI2001] 食物链
思路:
带权并查集可以很容易地知道两点的距离,不过依照题意,只有
A
、
B
、
C
A、B、C
A、B、C 三种动物,所以距离只能在
0
,
1
,
2
0,1,2
0,1,2 里取,只要将模板里的距离都
m
o
d
3
\mod 3
mod3 就可以了。
比如在
A
−
>
B
−
>
C
−
>
A
A->B->C->A
A−>B−>C−>A,这个食物链里,
d
i
s
t
(
B
,
A
)
=
1
dist(B,A)=1
dist(B,A)=1 表示
A
A
A 捕食
B
B
B,
d
i
s
t
(
C
,
A
)
=
2
dist(C,A)=2
dist(C,A)=2 表示
C
C
C 捕食
A
A
A,如果算出来距离等于
0
0
0 的话,就表示两者是同类。
code:
#include<bits/stdc++.h>
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define ld long double
#define all(x) x.begin(), x.end()
#define eps 1e-6
using namespace std;
const int maxn = 2e5 + 9;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
ll n, m;
int f[maxn], d[maxn];
int find(int x)
{
if(f[x] != x)
{
int oldf = f[x];
f[x] = find(oldf);
d[x] = (d[x] + d[oldf]) % 3;
}
return f[x];
}
void merge(int x, int y, int w)
{
int fx = find(x), fy = find(y);
if(fx == fy) return;
f[fx] = fy;
d[fx] = (-d[x] + d[y] + w + 3) % 3;
}
int dist(int x, int y)
{
int fx = find(x), fy = find(y);
if(fx != fy) return -1;
else return (d[x] - d[y] + 3) % 3;
}
void work()
{
cin >> n >> m;
for(int i = 1; i <= n; ++i) f[i] = i;
int ans = 0;
for(int i = 0; i < m; ++i)
{
int op, x, y;cin >> op >> x >> y;
if(x > n || y > n) ++ans;
else
{
int fx = find(x), fy = find(y);
if(op == 1)
{
if(fx != fy) merge(x, y, 0);
else if(dist(x, y) != 0) ++ans;
}
else
{
if(fx != fy) merge(x, y, 1);
else if(dist(x, y) != 1) ++ans;
}
}
}
cout << ans;
}
int main()
{
ios::sync_with_stdio(0);
// int TT;cin>>TT;while(TT--)
work();
return 0;
}
P2294 [HNOI2005]狡猾的商人
题意:
给定
n
n
n 和
m
m
m,表示序列长度和限制条件个数
每个限制条件给定一个左端点
x
x
x,右端点
y
y
y,以及区间和
∑
i
=
x
i
=
y
a
i
=
z
\sum_{i=x}^{i=y}a_i = z
∑i=xi=yai=z,判断条件是不是存在冲突
n
<
100
,
w
<
1000
,
x
<
=
y
n<100,w<1000,x<=y
n<100,w<1000,x<=y
思路:
带权并查集(虽然是道差分约束裸题
维护前缀和
题解可以看洛谷题解里天才byt的讲解
这里我说说对合并细节的理解
如果考虑
1
5
8
1 \ 5 \ 8
1 5 8 和
5
11
12
5 \ 11 \ 12
5 11 12 这两个数据
如果要满足
d
i
s
[
5
]
−
d
i
s
[
1
]
=
8
dis[5]-dis[1]=8
dis[5]−dis[1]=8,那么我们需要如下图这么合并
也就是合并时把左端点当作根节点
约束条件形如
x
y
z
x \ y \ z
x y z
int fx = find(x), fy = find(y);
dis[fy] = z + dis[x] - dis[y];
f[fy] = fx;
因为我们判断 x y z x \ y \ z x y z 是否合法时这么写的,所以合并操作需要按照上边
if(dis[y] - dis[x] == z) return 1;
code:
#include<bits/stdc++.h>
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define ld long double
#define all(x) x.begin(), x.end()
#define eps 1e-6
using namespace std;
const int maxn = 2e3 + 9;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
ll n, m;
ll f[maxn], dis[maxn];
int find(int x)
{
if(f[x] != x)
{
int oldf = f[x];
f[x] = find(f[x]);
dis[x] += dis[oldf];
}
return f[x];
}
bool merge(int x, int y, int z)
{
int fx = find(x), fy = find(y);
if(fx == fy)
{
if(dis[y] - dis[x] == z) return 1;
else return 0;
}
dis[fy] = z + dis[x] - dis[y];
f[fy] = fx;
return 1;
}
void work()
{
cin >> n >> m;
for(int i = 0; i <= n + 1; ++i) f[i] = i, dis[i] = 0;
bool flag = 1;
for(int i = 1, x, y, z; i <= m; ++i)
{
cin >> x >> y >> z;
++y;
if(!flag) continue;
flag = merge(x, y, z);
}
if(flag) cout << "true\n";
else cout << "false\n";
}
int main()
{
ios::sync_with_stdio(0);
int TT;cin>>TT;while(TT--)
work();
return 0;
}
How Many Answers Are Wrong
题意:
给出区间
[
1
,
n
]
[1,n]
[1,n],下面有
m
m
m 组数据,
l
,
r
,
v
l,r,v
l,r,v 表示
[
l
,
r
]
[l,r]
[l,r] 区间和为
v
v
v,每输入一组数据,判断此组条件是否与前面冲突,输出冲突的数据的个数
思路:
带权并查集
和上一个一样,这个题是记录冲突的数据的个数
code:
#include<bits/stdc++.h>
#define endl '\n'
#define ll long long
#define ull unsigned long long
#define ld long double
#define all(x) x.begin(), x.end()
#define eps 1e-6
using namespace std;
const int maxn = 2e5 + 9;
const int mod = 1e9 + 7;
const int inf = 0x3f3f3f3f;
const ll INF = 0x3f3f3f3f3f3f3f3f;
ll n, m;
ll f[maxn], dis[maxn];
ll ans = 0;
int find(int x)
{
if(f[x] != x)
{
int oldf = f[x];
f[x] = find(f[x]);
dis[x] += dis[oldf];
}
return f[x];
}
int merge(int x, int y, int z)
{
int fx = find(x), fy = find(y);
if(fx == fy)
{
if(dis[y] - dis[x] != z) return 1;
else return 0;
}
dis[fy] = z + dis[x] - dis[y];
f[fy] = fx;
return 0;
}
void work()
{
for(int i = 0; i <= n + 1; ++i) f[i] = i, dis[i] = 0;
ans = 0;
for(int i = 1, x, y, z; i <= m; ++i)
{
cin >> x >> y >> z;
++y;// 整体区间右移一个
ans += merge(x, y, z);
}
cout << ans << endl;
}
int main()
{
ios::sync_with_stdio(0);
// int TT;cin>>TT;while(TT--)
while(cin >> n >> m)
work();
return 0;
}