1. 问题描述:
幼儿园里有 N 个小朋友,老师现在想要给这些小朋友们分配糖果,要求每个小朋友都要分到糖果。但是小朋友们也有嫉妒心,总是会提出一些要求,比如小明不希望小红分到的糖果比他的多,于是在分配糖果的时候, 老师需要满足小朋友们的 K 个要求。幼儿园的糖果总是有限的,老师想知道他至少需要准备多少个糖果,才能使得每个小朋友都能够分到糖果,并且满足小朋友们所有的要求。
输入格式
输入的第一行是两个整数 N,K。接下来 K 行,表示分配糖果时需要满足的关系,每行 3 个数字 X,A,B。
如果 X=1.表示第 A 个小朋友分到的糖果必须和第 B 个小朋友分到的糖果一样多。
如果 X=2,表示第 A 个小朋友分到的糖果必须少于第 B 个小朋友分到的糖果。
如果 X=3,表示第 A 个小朋友分到的糖果必须不少于第 B 个小朋友分到的糖果。
如果 X=4,表示第 A 个小朋友分到的糖果必须多于第 B 个小朋友分到的糖果。
如果 X=5,表示第 A 个小朋友分到的糖果必须不多于第 B 个小朋友分到的糖果。
小朋友编号从 1 到 N。
输出格式
输出一行,表示老师至少需要准备的糖果数,如果不能满足小朋友们的所有要求,就输出 −1。
数据范围
1 ≤ N < 10 ^ 5,
1 ≤ K ≤ 10 ^ 5,
1 ≤ X ≤ 5,
1 ≤ A,B ≤ N
输入样例:
5 7
1 1 2
2 3 2
4 4 1
3 4 5
5 4 5
2 3 5
4 5 1
输出样例:
11
来源:https://www.acwing.com/problem/content/description/1171/
2. 思路分析:
分析题目可以知道题目中存在五个不等式对应的不等式组(等式可以转换为两个不等式),属于经典的差分约束的题目,可以参照差分约束的博客,我们理解其中的原理之后记住结论那么可以很快解决这个问题,如果求解的是最小值所以求解的是单源最长路径,如果求解的是最大值那么求解的是单源最短路径,单源最短路径与最长路径是对称的,我们在求解的时候只需要修改不等式的符号即可。首先需要将题目中的关系转换成对应的不等式组,因为求解的是"至少",所以求解的是最小值也即需要求解单源最长路径,那么就需要将不等式的符号修改为">"号的形式,可以得到下面的不等式组:
- X = 1:A = B <=> A >= B,B >= A
- X = 2:A < B <=> B >= A + 1
- X = 3:A >= B <=> A >= B
- X = 4:A > B <=> A >= B + 1
- X = 5:A <= B <=> B >= A
可以发现不管求解的是最短路径还是最长路径,在建图的时候都是不等号右边节点向左边节点连一条边,例如B >= A那么由A向B建一条边权为ck的边,并且我们还需要保证的一个前提是需要找一个源点使得从源点出发一定可以到达所有边,这里其实可以建一个虚拟源点,那么就一定可以到达每一个节点也就可以到达每一条边,根据题目的条件可知我们需要使得每个小朋友至少分到一个糖果,所以xi >= 1,建立一个虚拟源点满足的不等式关系为xi >= x0 + 1,其中x0 = 0,与原不等式是等价的,所以我们可以将0号点作为虚拟源点,向1~n号点连一条权重为1的边,然后使用spfa求解一遍单源最长路径即可,由于这道题目的数据局范围有点大,所以在使用spfa求解的时候如果不优化会超时,这里使用到的一个小技巧是使用一个变量cnt来记录迭代的次数,如果cnt超过一定的迭代次数还没有结束那么我们认为是图中是存在正环的(最长路判断的是正环),不过如果数据比较强的时候回判断错误,如果存在正环说明无解(不等式中存在矛盾的情况),返回-1,所以使用spfa算法的好处是不仅可以求解出最短/长路径,而且可以判断出是否存在正环或者负环。本质上差分约束问题与图论中的最短路与最长路是等价的,所以我们在遇到差分约束问题的时候将其转化为图论的问题即可,下面是图中测试用例建的图:
3. 代码如下:
from typing import List
import collections
class Solution:
def spfa(self, n: int, g: List[List[int]]):
# 因为求解的是最长路径所以初始化为负无穷
INF = -10 ** 10
dis = [INF] * (n + 1)
dis[0] = 0
vis = [0] * (n + 1)
vis[0] = 1
# count列表用来记录最长路径的边数,判断是否存在正环也即无解的情况
count = [0] * (n + 1)
q = collections.deque()
q.append(0)
cnt = 0
while q:
p = q.popleft()
vis[p] = 0
for next in g[p]:
# 求解的是最长路
if dis[next[0]] < dis[p] + next[1]:
dis[next[0]] = dis[p] + next[1]
count[next[0]] = count[p] + 1
if count[next[0]] >= n + 1:
return -1
cnt += 1
# 当迭代超过一定次数的时候我们认为图中存在正环提前退出算法
if cnt >= 3 * n: return -1
if vis[next[0]] == 0:
q.append(next[0])
res = 0
for i in range(1, n + 1):
res += dis[i]
return res
def process(self):
n, m = map(int, input().split())
g = [list() for i in range(n + 1)]
for i in range(m):
x, a, b = map(int, input().split())
if x == 1:
g[a].append((b, 0))
g[b].append((a, 0))
elif x == 2:
# a向b连一条权重为1的边
g[a].append((b, 1))
elif x == 3:
g[b].append((a, 0))
elif x == 4:
g[b].append((a, 1))
else:
g[a].append((b, 0))
# 注意0号点向其他点连一条边, 因为满足限制 x >= x0 + 1, 其中x0 = 0 (本质上是一个虚拟源点这样可以确保求解到最值)
for i in range(1, n + 1):
g[0].append((i, 1))
return self.spfa(n, g)
if __name__ == "__main__":
print(Solution().process())