在Maraboupy的example中有一个2_TensorflowExample.py
-
先是从.pb文件中读取了一个network
network = Marabou.read_tf(filename = filename, inputNames = inputNames, outputName = outputName)
这个network是从tensorflow的存储文件
.pb
中读出来的,然后转化成了变量network,这个network的类型是MarabouNetworkTF
类,这个类继承自MarabouNetwork
。MarabouNetwork
是一个抽象类,表示一个通用的Marabou Network。它有如下属性:
"""
Abstract class representing general Marabou network

Attributes:
    numVars (int): Total number of variables to represent network
    equList (list of :class:`~maraboupy.MarabouUtils.Equation`): Network equations
    reluList (list of tuples): List of relu constraint tuples, where each tuple contains the backward and forward variables
    maxList (list of tuples): List of max constraint tuples, where each tuple contains the set of input variables and output variable
    varsParticipatingInConstraints (set of int): Variables involved in some constraint
    lowerBounds (Dict[int, float]): Lower bounds of variables
    upperBounds (Dict[int, float]): Upper bounds of variables
    inputVars (list of numpy arrays): Input variables
    outputVars (numpy array): Output variables
"""
虽然我没细看,但我猜
MarabouNetworkTF
、MarabouNetworkNNet
、MarabouNetworkONNX
这三个类都是继承自MarabouNetwork类的。 -
回头接着看example,在得到变量network以后,设置了上下界,并且通过调用
vals, stats = network.solve("marabou.log")
得到了如下结果:
sat
input 0 = 10.0
input 1 = -9.778978626839324
output 0 = 19.755387832554714
output 1 = 194.0
这个结果是由solve过程打印的,我接下来打算重点看一下这个solve函数。
-
这里的solve不是
MarabouNetworkTF
中定义的,是在它的父抽象类MarabouNetwork
中直接实现的:
def solve(self, filename="", verbose=True, options=None):
    """Function to solve query represented by this network

    Args:
        filename (string): Path for redirecting output
        verbose (bool): If true, print out solution after solve finishes
        options (:class:`~maraboupy.MarabouCore.Options`): Object for specifying Marabou options, defaults to None

    Returns:
        (tuple): tuple containing:
            - vals (Dict[int, float]): Empty dictionary if UNSAT, otherwise a dictionary of SATisfying values for variables
            - stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed
    """
    ipq = self.getMarabouQuery()
    if options == None:
        options = MarabouCore.Options()
    vals, stats = MarabouCore.solve(ipq, options, filename)
    if verbose:
        if stats.hasTimedOut():
            print("TO")
        elif len(vals)==0:
            print("unsat")
        else:
            print("sat")
            for j in range(len(self.inputVars)):
                for i in range(self.inputVars[j].size):
                    print("input {} = {}".format(i, vals[self.inputVars[j].item(i)]))

            for i in range(self.outputVars.size):
                print("output {} = {}".format(i, vals[self.outputVars.item(i)]))

    return [vals, stats]
我觉得重点有两行,第一行通过
ipq = self.getMarabouQuery()
,获得一个InputQuery,并在vals, stats = MarabouCore.solve(ipq, options, filename)
中通过调用c++接口来获得vals和stats。接下来重点看一下
getMarabouQuery()
和MarabouCore.solve()
-
getMarabouQuery
def getMarabouQuery(self):
    """Function to convert network into Marabou InputQuery

    Returns:
        :class:`~maraboupy.MarabouCore.InputQuery`
    """
    ipq = MarabouCore.InputQuery()
    ipq.setNumberOfVariables(self.numVars)

    i = 0
    for inputVarArray in self.inputVars:
        for inputVar in inputVarArray.flatten():
            ipq.markInputVariable(inputVar, i)
            i+=1

    i = 0
    for outputVar in self.outputVars.flatten():
        ipq.markOutputVariable(outputVar, i)
        i+=1

    for e in self.equList:
        eq = MarabouCore.Equation(e.EquationType)
        for (c, v) in e.addendList:
            assert v < self.numVars
            eq.addAddend(c, v)
        eq.setScalar(e.scalar)
        ipq.addEquation(eq)

    for r in self.reluList:
        assert r[1] < self.numVars and r[0] < self.numVars
        MarabouCore.addReluConstraint(ipq, r[0], r[1])

    for m in self.maxList:
        assert m[1] < self.numVars
        for e in m[0]:
            assert e < self.numVars
        MarabouCore.addMaxConstraint(ipq, m[0], m[1])

    for l in self.lowerBounds:
        assert l < self.numVars
        ipq.setLowerBound(l, self.lowerBounds[l])

    for u in self.upperBounds:
        assert u < self.numVars
        ipq.setUpperBound(u, self.upperBounds[u])

    return ipq
变量ipq——也就是
getMarabouQuery()
的返回值,是一个MarabouCore.InputQuery
类型的。无语了,这个InputQuery
是定义在#include <InputQuery.h>
中的,我还没找到它在哪,晕了。小代码写得还挺长,先放一放这个ipq
。 -
MarabouCore.solve()
这个函数是在c++中编写的:
std::pair<std::map<int, double>, Statistics> solve(InputQuery &inputQuery, MarabouOptions &options, std::string redirect="") {
    // Arguments: InputQuery object, filename to redirect output
    // Returns: map from variable number to value
    std::map<int, double> ret;
    Statistics retStats;
    int output=-1;
    if(redirect.length()>0)
        output=redirectOutputToFile(redirect);
    try{
        bool verbosity = options._verbosity;
        unsigned timeoutInSeconds = options._timeoutInSeconds;
        bool dnc = options._dnc;

        Engine engine;
        engine.setVerbosity(verbosity);

        if(!engine.processInputQuery(inputQuery)) return std::make_pair(ret, *(engine.getStatistics()));
        if ( dnc )
        {
            unsigned initialDivides = options._initialDivides;
            unsigned initialTimeout = options._initialTimeout;
            unsigned numWorkers = options._numWorkers;
            unsigned onlineDivides = options._onlineDivides;
            float timeoutFactor = options._timeoutFactor;

            auto dncManager = std::unique_ptr<DnCManager>
                ( new DnCManager( numWorkers, initialDivides, initialTimeout, onlineDivides, timeoutFactor, DivideStrategy::LargestInterval, &inputQuery, verbosity ) );

            dncManager->solve( timeoutInSeconds );
            switch ( dncManager->getExitCode() )
            {
            case DnCManager::SAT:
            {
                retStats = Statistics();
                dncManager->getSolution( ret );
                break;
            }
            case DnCManager::TIMEOUT:
            {
                retStats = Statistics();
                retStats.timeout();
                return std::make_pair( ret, retStats );
            }
            default:
                return std::make_pair( ret, Statistics() ); // TODO: meaningful DnCStatistics
            }
        } else
        {
            if(!engine.solve(timeoutInSeconds)) return std::make_pair(ret, *(engine.getStatistics()));

            if (engine.getExitCode() == Engine::SAT)
                engine.extractSolution(inputQuery);
            retStats = *(engine.getStatistics());
            for(unsigned int i=0; i<inputQuery.getNumberOfVariables(); ++i)
                ret[i] = inputQuery.getSolutionValue(i);
        }
    }
    catch(const MarabouError &e){
        printf( "Caught a MarabouError. Code: %u. Message: %s\n", e.getCode(), e.getUserMessage() );
        return std::make_pair(ret, retStats);
    }
    if(output != -1)
        restoreOutputStream(output);
    return std::make_pair(ret, retStats);
}
彻底晕了,里面又提到了一个
Engine
类,是在#include <Engine.h>
中定义的,估计这个就是核心引擎了。今天就先看到这里吧…