Python做控制变量实验的工具

在控制变量实验中,我们通常需要固定住一个或几个参数,并遍历一个区间将参数代入实验中。
假设一个叫func的函数有三个参数,第一个参数固定,第二、三个参数是需要控制变量的,那么在参数离散取值的情况下,罗列出第二、第三个参数所有的情况,就需要用排列组合将不同参数代入func进而得到实验结果。本人实现了这样一个类,现在开源给大家使用。希望对大家有帮助

# coding:utf8
import time
import copy
import pandas as pd

class Experimenter:
    def __init__(self, verbose, log_filename='explog_{}.log'):
        self.statistic = []
        self.verbose = verbose
        self.log_filename = log_filename

    def statistic_to_csv(self):
        k = [i for i in self.statistic[0]]
        data = []

        for record in self.statistic:
            sub_data = [record[i] for i in k]
            data.append(sub_data)
        df = pd.DataFrame(data)
        df.columns = k
        df.to_csv(self.log_filename.format(time.strftime('%Y-%m-%d %H-%M-%S', time.gmtime(time.time()+8*60*60))))

    @staticmethod
    def arg_factory(args, scope):
        def inc(idx):
            scope_arg_index[idx] += 1
            if scope_arg_index[idx] >= scope_arg_length[idx]:
                scope_arg_index[idx] = 0
                if idx != scope_arg_num-1:
                    inc(idx+1)
                    return 0
                else:
                    return 0
            else:
                return 0

        scope_arg_length = [len(scope[i]) for i in scope]
        scope_arg_index = [0 for _ in scope]
        scope_arg_num = len(scope)
        while True:
            _tmp_args = copy.deepcopy(args)
            _idx = 0
            for ar in scope:
                _tmp_args[ar] = scope[ar][scope_arg_index[_idx]]
                _idx += 1
            yield _tmp_args
            FLAG = False
            for _i in range(scope_arg_num):
                if scope_arg_length[_i] - 1 > scope_arg_index[_i]:
                    FLAG = True
            if not FLAG:
                break
            inc(0)

    def run(self, execute_func, args={}, args_scope={}): # args中,需要尝试的args,留空,放入args_scope中
        for k in args:
            if k is None:
                assert k in args_scope, '{} 参数未被传入!!!'.format(k)
        get_next_args = lambda x : x.next()
        arg_fac = self.arg_factory(args, args_scope)
        while True:
            try:
                current_args = get_next_args(arg_fac)
                if self.verbose:
                    print(current_args)
                res = execute_func(**current_args)
                assert isinstance(res, dict), 'func result must be dict!'
                for k in current_args:
                    if k in res:
                        res['_'+k+'_'] = current_args[k]
                    else:
                        res[k] = current_args[k]
                self.statistic.append(res)
                if self.verbose:
                    print('-' * 150)
                    print('-' * 150)
                    print('完成一个任务。')

            except:
                if self.verbose:
                    print('所有参数被执行完毕,任务结束~')
                break
        self.statistic_to_csv()

以上是类,下面介绍用法:
第一步:实例化(需指定verbose、log_filepath)
第二步:传入做实验的用的函数(注意函数不要带括号)、实验函数的参数dict(所有关键字参数都要包含在键中,固定参数直接传入值,需改变参数写None)、实验的变量dict(用键值对形式,关键字参数为键,需要更改的值放在一个列表中)
下面用一个实例来说明,很简单,一看就会了。
假设要改变func的b、c参数的值进行实验:

    def func(a, b, c):
        return {'a+b':a+b, 'a+c': a+c}

    exp = Experimenter(verbose=False)
    exp.run(execute_func=func, args={'a': 1, 'b':None, 'c':None}, args_scope={'b':[1, 2, 3], 'c':[4, 5, 6]})

运行后,即可在同目录下自动生成的csv文件中查看结果:

a+c,a+b,c,b,a
0,5,2,4,1,1
1,6,2,5,1,1
2,7,2,6,1,1
3,5,3,4,2,1
4,6,3,5,2,1
5,7,3,6,2,1
6,5,4,4,3,1
7,6,4,5,3,1
8,7,4,6,3,1

注:实验函数的返回值必须是dict,空dict也可以。

希望能帮到大家~

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值