文章目录
CP-SAT的callback功能
作为约束规划求解器,CP-SAT的callback往往用于获取求解得到的可行解。每当求解器获得一个可行解时,则将可行解返回给callback,并进行相应的操作,那如何使用CP-SAT的callback功能呢?
相关案例代码
无目标的约束满足问题
给出如下问题:有三个变量 x,y,x
,它们的取值范围均为 {0,1,2}。求满足约束条件
x
≠
y
x\neq y
x=y 的 x,y,x
组合。
这里求的取值组合其实默认满足两个条件:一是变量取值要在定义域内,二是需要满足 x ≠ y x\neq y x=y。具体的代码如下:
from ortools.sat.python import cp_model
class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback):
"""Print intermediate solutions."""
def __init__(self, variables, count_limit):
print("callback initialization begins")
cp_model.CpSolverSolutionCallback.__init__(self)
self.__variables = variables
self.__solution_count = 0
self.__solution_limit = count_limit
print("callback initialization end")
def on_solution_callback(self):
"""callback function is called"""
self.__solution_count += 1
for v in self.__variables:
print('%s=%i' % (v, self.Value(v)), end=' ')
print()
if self.__solution_count >= self.__solution_limit:
print('Stop search after %i solutions' % self.__solution_limit)
else:
self.StopSearch()
def solution_count(self):
"""返回可行解数(callback次数)"""
return self.__solution_count
m = cp_model.CpModel()
x = m.NewIntVar(0, 2, 'x')
y = m.NewIntVar(0, 2, 'y')
z = m.NewIntVar(0, 2, 'z')
m.Add(x != y)
solver = cp_model.CpSolver()
solver.parameters.enumerate_all_solutions = True
solution_printer = VarArraySolutionPrinter([x, y, z], 50)
status = solver.Solve(m, solution_printer)
print('Status = %s' % solver.StatusName(status))
print('Number of solutions found: %i' % solution_printer.solution_count())
cp_model.CpSolverSolutionCallback
是一个回调类,每当求解器得到一个可行解时(不为空),则执行on_solution_callback
函数,将当前可行解传给回调类。
Solve(model, callback) 方法
上述代码中,solver.Solve()
传入两个参数,一个是问题模型 CpModel()
,一个是回调类 VarArraySolutionPrinter()
,我们看相应的源码可知,增加回调类使得 Solve()
中的 solve_wrapper
增加了两个操作:
solve_wrapper.AddSolutionCallback(solution_callback)
solve_wrapper.ClearSolutionCallback(solution_callback)
这两个操作底层均为 C++ 封装,因此我们知道这一层即可,回调类是在 Solve()
方法当中执行这两句语句,且值得注意的是,Solve()
是 solver
的成员方法,在回调类中的 self
指的是 solver
本身,可以在回调类当中读取 solver
的当前属性值。
SolveWithSolutionCallback(model, callback) 方法
该方法与Solve()
方法用法完全一致,但目前已建议弃用。将上述的 Solve()
方法换成 SolveWithSolutionCallback
方法即可,且输出结果完全一致。
status = solver.SolveWithSolutionCallback(m, solution_printer)
来看下该方法的源码,可知该方法最终还是返回并调用 Solve()
方法。
def SolveWithSolutionCallback(self, model, callback):
warnings.warn(
'SolveWithSolutionCallback is deprecated; use Solve() with' +
'the callback argument.', DeprecationWarning)
return self.Solve(model, callback)
SearchForAllSolutions(model, callback)方法
该方法与前述两种方法用法一致,该方法是找到问题中所有的满足约束的解,并将解传给回调类。与前面两种方法的区别是,该方法仅支持无目标的约束满足问题。
在官方文档中,也建议弃用该方法,而替换方案也是用Solve()
方法,SearchForAllSolutions(model, callback)
与下面两行代码是等价的:
solver.parameters.enumerate_all_solutions = True
status = solver.Solve(model, callback)
且在用 Solve()
的等价形式时,即时添加了问题目标,也不会报错;而用 SearchForAllSolutions
方法时,在问题添加了目标函数之后,运行会报错。
带目标的优化问题
延续上面的例子,这里我们为问题增加目标:最大化 x + y − z x+y-z x+y−z。仅增加下面一行代码:
m.Maximize(x+y-z)
前面提到,在回调类中,self
指的是 solver
本身,因此我们在回调类中加入打印 solver
的部分信息的操作。
class VarArraySolutionPrinter(cp_model.CpSolverSolutionCallback):
"""Print intermediate solutions."""
def __init__(self, variables, count_limit):
print("callback initialization begins")
cp_model.CpSolverSolutionCallback.__init__(self)
self.__variables = variables
self.__solution_count = 0
self.__solution_limit = count_limit
print("callback initialization end")
def on_solution_callback(self):
"""callback function is called"""
self.__solution_count += 1
for v in self.__variables:
print('%s=%i' % (v, self.Value(v)), end=' ')
"""solver信息"""
print('Obj = %i' % self.ObjectiveValue(), end=', ')
print('BestBound= %i' % self.BestObjectiveBound(), end=', ')
print('WallTime: %f s' % self.WallTime())
print()
if self.__solution_count >= self.__solution_limit:
print('Stop search after %i solutions' % self.__solution_limit)
self.StopSearch()
def solution_count(self):
"""返回可行解数(callback次数)"""
return self.__solution_count
在每次回调可行解时,顺带打印当前解的目标值、最优解以及求解时间,返回结果如下:
callback initialization begins
callback initialization end
x=1 y=2 z=0 Obj = 3, BestBound= 3, WallTime: 0.001990 s
这里由于目标非常简单,因此基于目标函数直接取边界点可以到达最优,回调函数就只被调用一次就结束,对于复杂的问题,求解过程找到的可行解会被打印输出。
OR-Tools自带的三种callback类
上述内容介绍了OR-Tools中如何自定义callback类,主要有两步,一个是继承并初始化基类 CpSolverSolutionCallback,接着是重写 on_solution_callback() 方法,在求解过程中,solver
会不断地调用基类当中的 OnSolutionCallback() 方法,该方法返回的就是 on_solution_callback() 方法。
除了自定义callback类之外,OR-Tools自带了三种回调函数,如果不是有特别的定制化需求,其实OR-Tools自带的三种回调函数也能使用,且使用方式相同简便。
1. ObjectiveSolutionPrinter()
该类的使用方式非常简单,先定义一个回调对象传入 solver.Solve()
即可。
solution_printer = cp_model.ObjectiveSolutionPrinter()
status = solver.Solve(m, solution_printer)
查看源码可知,该类的写法与文首介绍的写法相近,大家可以根据源码的 on_solution_callback()
方法知道该方法返回的内容,包括返回的解的序号、得到当前解经过的时间、当前解的目标值。
class ObjectiveSolutionPrinter(CpSolverSolutionCallback):
"""Display the objective value and time of intermediate solutions."""
def __init__(self):
CpSolverSolutionCallback.__init__(self)
self.__solution_count = 0
self.__start_time = time.time()
def on_solution_callback(self):
"""Called on each new solution."""
current_time = time.time()
obj = self.ObjectiveValue()
print('Solution %i, time = %0.2f s, objective = %i' %
(self.__solution_count, current_time - self.__start_time, obj))
self.__solution_count += 1
def solution_count(self):
"""Returns the number of solutions found."""
return self.__solution_count
2. VarArrayAndObjectiveSolutionPrinter()
该类的使用方式如上,不过由于需要打印变量的信息,因此还需要在实例化的时候传入模型的变量作为参数,以文首的问题为例,我们想要打印变量 x,y,z
的变量值,则可以通过如下方式进行处理。
solution_printer = cp_model.VarArrayAndObjectiveSolutionPrinter([x, y, z])
status = solver.Solve(m, solution_printer)
查看源码的 on_solution_callback()
方法可知该回调类返回的内容,包括返回的解的序号、得到当前解经过的时间、当前解的目标值,以及打印传入变量在当前解中的值。
class VarArrayAndObjectiveSolutionPrinter(CpSolverSolutionCallback):
"""Print intermediate solutions (objective, variable values, time)."""
def __init__(self, variables):
CpSolverSolutionCallback.__init__(self)
self.__variables = variables
self.__solution_count = 0
self.__start_time = time.time()
def on_solution_callback(self):
"""Called on each new solution."""
current_time = time.time()
obj = self.ObjectiveValue()
print('Solution %i, time = %0.2f s, objective = %i' %
(self.__solution_count, current_time - self.__start_time, obj))
for v in self.__variables:
print(' %s = %i' % (v, self.Value(v)), end=' ')
print()
self.__solution_count += 1
def solution_count(self):
"""Returns the number of solutions found."""
return self.__solution_count
3. VarArraySolutionPrinter()
这是OR-Tools的第三种回调类,使用方法与前两种类一致,如果需要打印变量在可行解当中的值,可以在实例化回调对象时将变量传入。如下所示:
solution_printer = cp_model.VarArraySolutionPrinter([x, y, z])
status = solver.Solve(m, solution_printer)
类似的,查看的源码可知,该回调类每次有可行解时,除了不打印可行解的目标值,其余打印信息与 VarArrayAndObjectiveSolutionPrinter()
类一致。
class VarArraySolutionPrinter(CpSolverSolutionCallback):
"""Print intermediate solutions (variable values, time)."""
def __init__(self, variables):
CpSolverSolutionCallback.__init__(self)
self.__variables = variables
self.__solution_count = 0
self.__start_time = time.time()
def on_solution_callback(self):
"""Called on each new solution."""
current_time = time.time()
print('Solution %i, time = %0.2f s' %
(self.__solution_count, current_time - self.__start_time))
for v in self.__variables:
print(' %s = %i' % (v, self.Value(v)), end=' ')
print()
self.__solution_count += 1
def solution_count(self):
"""Returns the number of solutions found."""
return self.__solution_count