Maraboupy阅读理解

本文深入探讨了MarabouPy库中MarabouNetwork类的功能与实现细节,特别是其solve方法的工作流程,包括从创建MarabouCore.InputQuery到调用MarabouCore.solve进行求解的过程。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

在Maraboupy的example中有一个2_TensorflowExample.py

  1. 先是从.pd中读取了一个network

    network = Marabou.read_tf(filename = filename, inputNames = inputNames, outputName = outputName)
    

    这个network是从tensorflow的存储文件.pd中读出来的,然后转化成了变量network,这个network的类型是MarabouNetWorkTF类,这个类继承自MarabouNetwork

    MarabouNetwork是一个抽象类,表示一个通用的Marabou Network。它有如下参数:

    """
    Abstract class representing general Marabou network
    
    Attributes:
        numVars (int): Total number of variables to represent network
        equList (list of :class:`~maraboupy.MarabouUtils.Equation`): Network equations
        reluList (list of tuples): List of relu constraint tuples, where each tuple contains the backward and forward variables
        maxList (list of tuples): List of max constraint tuples, where each tuple conatins the set of input variables and output variable
        varsParticipatingInConstraints (set of int): Variables involved in some constraint
        lowerBounds (Dict[int, float]): Lower bounds of variables
        upperBounds (Dict[int, float]): Upper bounds of variables
        inputVars (list of numpy arrays): Input variables
        outputVars (numpy array): Output variables
    """
    

    虽然我没细看,但我猜MarabouNetWorkTFMarabouNetWorkNNetMarabouNetWorkONNX这三个类都是继承自MarabouNetwork类的。

  2. 回头接着看example,在得到变量network以后,设置了上下界,并且通过调用vals, stats = network.solve("marabou.log")得到了如下结果:

    sat
    input 0 = 10.0
    input 1 = -9.778978626839324
    output 0 = 19.755387832554714
    output 1 = 194.0
    

    这个结果是由solve过程打印的,我接下来打算重点看一下这个solve函数。

  3. 这里的solve不是MarabouNetWorkTF中定义的,是在它的父抽象类MarabouNetwork中直接实现的:

    def solve(self, filename="", verbose=True, options=None):
        """Function to solve query represented by this network
    
        Args:
            filename (string): Path for redirecting output
            verbose (bool): If true, print out solution after solve finishes
            options (:class:`~maraboupy.MarabouCore.Options`): Object for specifying Marabou options, defaults to None
    
        Returns:
            (tuple): tuple containing:
                - vals (Dict[int, float]): Empty dictionary if UNSAT, otherwise a dictionary of SATisfying values for variables
                - stats (:class:`~maraboupy.MarabouCore.Statistics`): A Statistics object to how Marabou performed
        """
        ipq = self.getMarabouQuery()
        if options == None:
            options = MarabouCore.Options()
        vals, stats = MarabouCore.solve(ipq, options, filename)
        if verbose:
            if stats.hasTimedOut():
                print("TO")
            elif len(vals)==0:
                print("unsat")
            else:
                print("sat")
                for j in range(len(self.inputVars)):
                    for i in range(self.inputVars[j].size):
                        print("input {} = {}".format(i, vals[self.inputVars[j].item(i)]))
    
                for i in range(self.outputVars.size):
                    print("output {} = {}".format(i, vals[self.outputVars.item(i)]))
    
        return [vals, stats]
    

    我觉得重点有两行,第一行通过ipq = self.getMarabouQuery,获得一个InputQuery,并在vals, stats = MarabouCore.solve(ipq, options, filename)中通过调用c++接口来获得vals和stats。

    接下来重点看一下getMarabouQuery()MarabouCore.solve()

  4. getMarabouQuery

    def getMarabouQuery(self):
            """Function to convert network into Marabou InputQuery
    
            Returns:
                :class:`~maraboupy.MarabouCore.InputQuery`
            """
            ipq = MarabouCore.InputQuery()
            ipq.setNumberOfVariables(self.numVars)
    
            i = 0
            for inputVarArray in self.inputVars:
                for inputVar in inputVarArray.flatten():
                    ipq.markInputVariable(inputVar, i)
                    i+=1
    
            i = 0
            for outputVar in self.outputVars.flatten():
                ipq.markOutputVariable(outputVar, i)
                i+=1
    
            for e in self.equList:
                eq = MarabouCore.Equation(e.EquationType)
                for (c, v) in e.addendList:
                    assert v < self.numVars
                    eq.addAddend(c, v)
                eq.setScalar(e.scalar)
                ipq.addEquation(eq)
    
            for r in self.reluList:
                assert r[1] < self.numVars and r[0] < self.numVars
                MarabouCore.addReluConstraint(ipq, r[0], r[1])
    
            for m in self.maxList:
                assert m[1] < self.numVars
                for e in m[0]:
                    assert e < self.numVars
                MarabouCore.addMaxConstraint(ipq, m[0], m[1])
    
            for l in self.lowerBounds:
                assert l < self.numVars
                ipq.setLowerBound(l, self.lowerBounds[l])
    
            for u in self.upperBounds:
                assert u < self.numVars
                ipq.setUpperBound(u, self.upperBounds[u])
                
            return ipq
    

    变量ipq——也就是getMarabouQuery()的返回值,是一个MarabouCore.InputQuery类型的。无语了,这个InputQuery是定义在#include <InputQuery.h>中的,我还没找到它在哪,晕了。小代码写得还挺长,先放一放这个ipq

  5. MarabouCore.solve()

    这个函数是在c++中编写的:

    std::pair<std::map<int, double>, Statistics> solve(InputQuery &inputQuery, MarabouOptions &options,
                                                   std::string redirect="") {
        // Arguments: InputQuery object, filename to redirect output
        // Returns: map from variable number to value
        std::map<int, double> ret;
        Statistics retStats;
        int output=-1;
        if(redirect.length()>0)
            output=redirectOutputToFile(redirect);
        try{
            bool verbosity = options._verbosity;
            unsigned timeoutInSeconds = options._timeoutInSeconds;
            bool dnc = options._dnc;
    
            Engine engine;
            engine.setVerbosity(verbosity);
    
            if(!engine.processInputQuery(inputQuery)) return std::make_pair(ret, *(engine.getStatistics()));
            if ( dnc )
            {
                unsigned initialDivides = options._initialDivides;
                unsigned initialTimeout = options._initialTimeout;
                unsigned numWorkers = options._numWorkers;
                unsigned onlineDivides = options._onlineDivides;
                float timeoutFactor = options._timeoutFactor;
    
                auto dncManager = std::unique_ptr<DnCManager>
                    ( new DnCManager( numWorkers, initialDivides, initialTimeout, onlineDivides,
                                      timeoutFactor, DivideStrategy::LargestInterval,
                                      &inputQuery, verbosity ) );
    
                dncManager->solve( timeoutInSeconds );
                switch ( dncManager->getExitCode() )
                {
                case DnCManager::SAT:
                {
                    retStats = Statistics();
                    dncManager->getSolution( ret );
                    break;
                }
                case DnCManager::TIMEOUT:
                {
                    retStats = Statistics();
                    retStats.timeout();
                    return std::make_pair( ret, retStats );
                }
                default:
                    return std::make_pair( ret, Statistics() ); // TODO: meaningful DnCStatistics
                }
            } else
            {
                if(!engine.solve(timeoutInSeconds)) return std::make_pair(ret, *(engine.getStatistics()));
    
                if (engine.getExitCode() == Engine::SAT)
                    engine.extractSolution(inputQuery);
                retStats = *(engine.getStatistics());
                for(unsigned int i=0; i<inputQuery.getNumberOfVariables(); ++i)
                    ret[i] = inputQuery.getSolutionValue(i);
            }
        }
        catch(const MarabouError &e){
            printf( "Caught a MarabouError. Code: %u. Message: %s\n", e.getCode(), e.getUserMessage() );
            return std::make_pair(ret, retStats);
        }
        if(output != -1)
            restoreOutputStream(output);
        return std::make_pair(ret, retStats);
    }
    

    彻底晕了,里面又提到了一个Engine类,是在#include <Engine.h>中定义的,估计这个就是核心引擎了。今天就先看到这里吧…

内容概要:本文《2025年全球AI Coding市场洞察研究报告》由亿欧智库发布,深入分析了AI编程工具的市场现状和发展趋势。报告指出,AI编程工具在2024年进入爆发式增长阶段,成为软件开发领域的重要趋势。AI编程工具不仅简化了代码生成、调试到项目构建等环节,还推动编程方式从人工编码向“人机协同”模式转变。报告详细评估了主流AI编程工具的表现,探讨了其商业模式、市场潜力及未来发展方向。特别提到AI Agent技术的发展,使得AI编程工具从辅助型向自主型跃迁,提升了任务执行的智能化和全面性。报告还分析了AI编程工具在不同行业和用户群体中的应用,强调了其在提高开发效率、减少重复工作和错误修复方面的显著效果。最后,报告预测2025年AI编程工具将在精准化和垂直化上进一步深化,推动软件开发行业进入“人机共融”的新阶段。 适合人群:具备一定编程基础,尤其是对AI编程工具有兴趣的研发人员、企业开发团队及非技术人员。 使用场景及目标:①了解AI编程工具的市场现状和发展趋势;②评估主流AI编程工具的性能和应用场景;③探索AI编程工具在不同行业中的具体应用,如互联网、金融、游戏等;④掌握AI编程工具的商业模式和盈利空间,为企业决策提供参考。 其他说明:报告基于亿欧智库的专业研究和市场调研,提供了详尽的数据支持和前瞻性洞察。报告不仅适用于技术从业者,也适合企业管理者和政策制定者,帮助他们在技术和商业决策中更好地理解AI编程工具的价值和潜力。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值