python中的argparse模块

1. argparse模块简介

  • argparse是Python中的一个常用模块,和sys.argv()功能类似,主要用于编写命令行接口:对于程序所需要的参数,它可以进行正确的解析。另外,argparse还可以自动的生成help和 usage信息,当程序的参数无效时,它可以自动生成错误信息。

  • 官方文档

    https://docs.python.org/3/library/argparse.html


2. argparse例子

  • 下列代码是python程序,它可以接受多个整数,并返回它们的或者最大值

    import argparse
    
    parser = argparse.ArgumentParser(description='处理一些整数')
    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='累加器中的整数')
    parser.add_argument('--sum', dest='accumulate', action='store_const',
                        const=sum, default=max,
                        help='对整数进行求和(默认: 寻找最大值)')
    args = parser.parse_args()
    print(args.accumulate(args.integers))
    

    假定上述代码保存为test.py,它可以在命令行执行并自动提供有用的help信息:

    (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py -h
    usage: test.py [-h] [--sum] N [N ...]
    
    处理一些整数
    
    positional arguments:
      N           累加器中的整数
    
    optional arguments:
      -h, --help  show this help message and exit
      --sum       对整数进行求和(默认: 寻找最大值)
    
    

    该程序可以接受相应的参数并给出相应的输出:

    (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py 1 2 3 4 
    4
    (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py 1 2 3 4 --sum
    10
    
      
      如果传入无效的参数,它可以自动生成`error`信息:
      
      ```bash
      (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py  a b 
      usage: test.py [-h] [--sum] N [N ...]
    test.py: error: argument N: invalid int value: 'a'
    

3. argparse 的三个主要函数

Creating a parser

  • parser=argparse.ArgumentParser()

    使用argparse的第一步是创建ArgumentParser对象:

    parser = argparse.ArgumentParser(description='处理一些整数')
    

    ArgumentParser对象保存了所有必要的信息,用以将命令行参数解析为相应的python数据类型。

Adding arguments

  • parser.add_argument()

    调用add_argument()向ArgumentParser对象添加命令行参数信息,这些信息告诉ArgumentParser对象如何处理命令行参数。可以通过调用parse_agrs()来使用这些命令行参数。例如:

    parser.add_argument('integers', metavar='N', type=int, nargs='+',
                        help='累加器中的整数')
    parser.add_argument('--sum', dest='accumulate', action='store_const',
                        const=sum, default=max,
                        help='对整数进行求和(默认: 寻找最大值)')
    

    之后,调用parse.args()将返回一个对象,它具有两个属性:integersaccumulate,前者是一个整数集合,后者根据命令行参数的不同,代表不同的函数,可以是max(),也可以是sum()。

Parsing arguments

  • args = parser.parse_args()

    通过调用parse_args()来解析ArgumentParser对象中保存的命令行参数:将命令行参数解析成相应的数据类型并采取相应的动作,它返回一个Namespace对象。

    >>> parser.parse_args(['--sum', '7', '-1', '42'])
    Namespace(accumulate=<built-in function sum>, integers=[7, -1, 42])
    

    在实际的python脚本中,parse_args()一般并不使用参数,它的参数由sys.argv决定。


4. ArgumentParser对象(ArgumentParser objects)

class argparse.ArgumentParser(prog=None
                            , usage=None
                            , description=None
                            , epilog=None
                            , parents=[]
                            , formatter_class=argparse.HelpFormatter
                            , argument_default=None
                            , conflict_handler=’error’
                            , add_help=True
                            , allow_abbrev=True
                            )

对象参数简介

  • 创建一个新的 ArgumentParser 对象。所有参数都应作为关键字参数传递。每个参数在下面都有更详细的描述,但简而言之,它们是:
    • prog - The name of the program (default: sys.argv[0])
      • 程序文件名
    • usage - The string describing the program usage (default: generated from arguments added to parser)
      • 程序使用说明
    • description - Text to display before the argument help (default: none)
      • 程序目的说明
    • epilog - Text to display after the argument help (default: none)
      • 程序说明后记
    • parents - A list of ArgumentParser objects whose arguments should also be included
      • ArgumentParser对象的父对象的参数列表。
    • formatter_class - A class for customizing the help output
      • help信息的说明格式
    • prefix_chars - The set of characters that prefix optional arguments (default: ‘-‘)
      • 命令行参数前缀。
    • fromfile_prefix_chars - The set of characters that prefix files from which additional arguments should be read (default: None)
    • argument_default - The global default value for arguments (default: None)
      • 参数全局默认值
    • conflict_handler - The strategy for resolving conflicting optionals (usually unnecessary)
      • 冲突处理
    • add_help - Add a -h/--help option to the parser (default: True)
      • 是否增加help选项
    • allow_abbrev - Allows long options to be abbreviated if the abbreviation is unambiguous. (default: True)
      • 是否使用参数缩写
    • exit_on_error - Determines whether or not ArgumentParser exits with error info when an error occurs. (default: True)

常用对象参数具体用法

  • prog(程序文件名):

    默认情况下,ArgumentParser对象使用sys.argv[0]来决定如何在help信息中显示程序的文件名。
    例如:下列代码的文件名为myprogram.py

    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('--foo', help='foo help')
    args = parser.parse_args()
    

    它的help信息将显示程序的文件名,而不管程序在何处进行调用。

    (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python myprogram.py  --help
    usage: myprogram.py [-h] [--foo FOO]
    
    optional arguments:
      -h, --help  show this help message and exit
      --foo FOO   foo help
    (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ cd ..
    (base) mengzhichaodeMacBook-Pro:16、python代码 mengzhichao$ python Sample_Code/myprogram.py --help
    usage: myprogram.py [-h] [--foo FOO]
    
    optional arguments:
      -h, --help  show this help message and exit
      --foo FOO   foo help
    

    使用prog参数可以修改ArgumentParser对象的默认文件名:

    $parser = argparse.ArgumentParser(prog='myprogram')
    $parser.print_help()
    usage: myprogram [-h]
    
    optional arguments:
     -h, --help  show this help message and exit
    

    不论程序的文件名是来自sys.argv[0],不是来自参数prog,都可以通过%(prog)s来引用程序的文件名。

    $ parser = argparse.ArgumentParser(prog='myprogram')
    $ parser.add_argument('--foo', help='foo of the %(prog)s program')
    $ parser.print_help()
    usage: myprogram [-h] [--foo FOO]
    
    optional arguments:
     -h, --help  show this help message and exit
     --foo FOO   foo of the myprogram program
    
  • usage 程序使用说明

    默认情况下,ArgumentParser对象会根据它所包括的参数来自动的生成使用说明。

    $ parser = argparse.ArgumentParser(prog='PROG')
    $ parser.add_argument('--foo', nargs='?', help='foo help')
    $ parser.add_argument('bar', nargs='+', help='bar help')
    $ parser.print_help()
    usage: PROG [-h] [--foo [FOO]] bar [bar ...]
    
    positional arguments:
     bar          bar help
    
    optional arguments:
     -h, --help   show this help message and exit
     --foo [FOO]  foo help
    

    可以使用参数usage来修改程序的使用说明信息。

    $ parser = argparse.ArgumentParser(prog='PROG', usage='%(prog)s [options]')
    $ parser.add_argument('--foo', nargs='?', help='foo help')
    $ parser.add_argument('bar', nargs='+', help='bar help')
    $ parser.print_help()
    usage: PROG [options]
    
    positional arguments:
     bar          bar help
    
    optional arguments:
     -h, --help   show this help message and exit
     --foo [FOO]  foo help
    
  • description 程序目的说明

    一般情况下,ArgumentParser对象会使用参数description来说明程序的用途以及目的。description信息一般位于usage信息和help信息之间

    $ parser = argparse.ArgumentParser(description='A foo that bars')
    $ parser.print_help()
    usage: argparse.py [-h]
    
    A foo that bars
    
    optional arguments:
     -h, --help  show this help message and exit
    

    默认情况下,description会自动调整间距来适应显示空间。可以通过formatter_class类来修改它的显示方式。

  • epilog 程序说明后记

    一些程序喜欢在参数描述信息之后添加额外的信息,这些信息可以通过参数epilog来指定。

    $  parser = argparse.ArgumentParser(
    ...     description='A foo that bars',
    ...     epilog="And that's how you'd foo a bar")
    $ parser.print_help()
    usage: argparse.py [-h]
    
    A foo that bars
    
    optional arguments:
     -h, --help  show this help message and exit
    
    And that's how you'd foo a bar
    

    description类似,epilog会自动调整间距来适应显示空间。可以通过formatter_class类来修改它的显示方式。

5. add_argument()方法

函数参数简介

  • 函数参数
    • name or flags - Either a name or a list of option strings, e.g. foo or -f, --foo.
      • 参数名或者参数标识
      • -的为可选参数(optional parameter)
      • 不带-的为必选参数(positional parametrer)。
    • action - The basic type of action to be taken when this argument is encountered at the command line.
      • 参数的处理方法
    • nargs - The number of command-line arguments that should be consumed.
      • 参数的数量
    • const - A constant value required by some action and nargs selections.
      • 参数的常量值
    • default - The value produced if the argument is absent from the command line and if it is absent from the namespace object.
      • 参数默认值
    • type - The type to which the command-line argument should be converted.
      • 参数的数据类型
    • choices - A container of the allowable values for the argument.
      • 参数取值范围
    • required - Whether or not the command-line option may be omitted (optionals only).
      • 参数是否可以忽略不写 ,仅对可选参数有效
    • help - A brief description of what the argument does.
      • 参数的说明信息
    • metavar - A name for the argument in usage messages.
      • 参数在说明信息usage中的名称
    • dest - The name of the attribute to be added to the object returned by parse_args().
      • 对象的属性名

常用函数参数具体用法

  • default:默认参数值

    """程序名称:myprogram.py"""
    import argparse
    
    def get_parser():
        parser = argparse.ArgumentParser(description="Demo of argparse")
        parser.add_argument('--name'
                            , default='Great'   # 默认值为Great
                           )
    
        return parser
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
        name = args.name
        print('Hello {}'.format(name))
    
    

    终端运行结果

    $ python myprogram.py 
    Hello Great
    $ python myprogram.py  --name mengzhichao
    Hello mengzhichao
    
  • required:表示这个参数是否一定要设置

    如果设置了required=True,则在实际运行的时候不设置该参数将报错(即使在你已经指定了default的情况下):

    $ python myprogram.py 
    usage: myprogram.py [-h] --name NAME
    myprogram.py: error: the following arguments are required: --name
    $ python myprogram.py --name mengzhichao
    Hello mengzhichao
    
  • type:参数类型

    默认的参数类型是str类型,如果你的程序需要一个整数或者布尔型参数,你需要设置type=inttype=bool,下面是一个打印平方的例子:

    """程序名称:myprogram.py"""
    import argparse
    
    
    def get_parser():
        parser = argparse.ArgumentParser(
            description='Calculate square of a given number')
        parser.add_argument('-number'
                            , type=int # 参数类型
                            , required=True # 该参数是否一定需要
                            , default=10 # 参数默认值
                           )
    
        return parser
    
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
        res = args.number ** 2
        print('square of {} is {}'.format(args.number, res))
    

    如果不是指定的参数类型,那么会报错

    $ python myprogram.py -number 10
    square of 10 is 100
    $ python myprogram.py -number 'str'
    usage: myprogram.py [-h] -number NUMBER
    myprogram.py: error: argument -number: invalid int value: 'str'
    
  • choices: 参数值只能从几个选项里面选择

    import argparse
    
    
    def get_parser():
        parser = argparse.ArgumentParser(
            description='choices demo')
        parser.add_argument('-arch'
                            , required=True
                            , choices=['alexnet', 'vgg']
                            , type=str)
    
        return parser
    
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
        print('the arch of CNN is {}'.format(args.arch))
    

    如果向下面第一条这样执行会报错,因为第一条所给的-arch参数resnet不在备选的choices之中,所以会报错:

    $ python myprogram.py -arch resnet
    usage: myprogram.py [-h] -arch {alexnet,vgg}
    myprogram.py: error: argument -arch: invalid choice: 'resnet' (choose from 'alexnet', 'vgg')
    
    $ python myprogram.py -arch alexnet
    the arch of CNN is alexnet
    
  • help:指定参数的说明信息

    import argparse
    
    def get_parser():
        parser = argparse.ArgumentParser(
            description='help demo')
        parser.add_argument('-arch'
                            , required=True
                            , choices=['alexnet', 'vgg']
                            , help='the architecture of CNN, \
                            at this time we only support alexnet and vgg.')
    
    
        return parser
    
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
    
        print('the arch of CNN is {}'.format(args.arch))
    

    在命令行加-h--help参数运行该命令,获取帮助信息的时候,结果如下:

    $ python myprogram.py --help
    usage: myprogram.py [-h] -arch {alexnet,vgg}
    
    help demo
    
    optional arguments:
      -h, --help           show this help message and exit
      -arch {alexnet,vgg}  the architecture of CNN, at this time we only support
                           alexnet and vgg.
    
    
  • dest:设置参数在代码中的别名—被添加到 parse_args() 所返回对象上的属性名:

    注意:为该参数设置别名后,只能用别名在namespace中调用。(参考下述代码8 22 行)

    import argparse
    
    
    def get_parser():
        parser = argparse.ArgumentParser(
            description='help demo')
        parser.add_argument('-arch'
                            , dest='ar'
                            , required=False
                            , default='alexnet'
                            , choices=['alexnet', 'vgg']
                            , help='the architecture of CNN, \
                            at this time we only support alexnet and vgg.')
    
        return parser
    
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
        
        print('the arch of CNN is {}'.format(args.ar))
    
  • nargs:设置参数在使用时可以提供的个数

    使用方式如下

    parser.add_argument(‘-name’, nargs=x)

    其中x的候选值和含义如下:

    含义
    N参数的绝对个数(如:3)
    ?0或者1个参数
    *0或所有参数
    +所有,并且至少一个参数

    如下例子:

    import argparse
    
    
    def get_parser():
        parser = argparse.ArgumentParser(
            description='nargs demo')
        parser.add_argument('-name'
                            , type=str  # 类型
                            , dest='nas'  # 别名
                            , metavar='names'  # 在usage中显示的名称
                            , required=False  # 是否一定需要
                            , default='A'  # 默认
                            , choices=['A', 'B', 'C'] # 参数选择范围
                            , nargs='+'
                            , )
    
        return parser
    
    
    if __name__ == '__main__':
        parser = get_parser()
        args = parser.parse_args()
        names = ', '.join(args.nas)
        print('Hello to {}'.format(names))
    
    

    执行命令和结果如下:

    $ python myprogram.py -name A B C
    Hello to A, B, C
    
    $ python myprogram.py -name A g
    usage: myprogram.py [-h] [-name names [names ...]]
    myprogram.py: error: argument -name: invalid choice: 'g' (choose from 'A', 'B', 'C')
    

6. parse_args()方法

  • ArgumentParser.parse_args(args=None, namespace=None)
    

    parse_args()方法将命令行参数字符串转换为相应对象并赋值给Namespace对象的相应属性,默认返回一个Namespace对象。

    1. args: 字符串列表,默认来自sys.argv。
    2. namespace:对象名,默认是一个空的namespace对象。

7. namespace 对象

  • class argparse.Namespace
    

    调用parse_args()的返回值是一个Namespace对象,它具有很多属性,每个属性都代表相应的命令行参数。

    Namespace对象是一个非常简单的类,可以通过vars()将之转换成字典类型。例如:

    >>> import argparse
    >>> parser = argparse.ArgumentParser()
    >>> parser.add_argument('--test')
    _StoreAction(option_strings=['--test'], dest='test', nargs=None, const=None, default=None, type=None, choices=None, help=None, metavar=None)
    >>> args = parser.parse_args(['--test', 'TEST'])
    >>> args
    Namespace(test='TEST')
    >>> vars(args)
    {'test': 'TEST'}
    
  • 另外还可以将ArgumentParser对象赋值给别的命令空间,而不是新建一个Namespace对象,例如:

    $ class C:
    ...     pass
    ...
    $ c = C()
    $ parser = argparse.ArgumentParser()
    $ parser.add_argument('--foo')
    $ parser.parse_args(args=['--foo', 'BAR'], namespace=c)
    $ c.foo
    'BAR'
    

8.argparse 实例(以决策树参数、随机森林参数为例)

直接添加

  • 主体

    def get_parser():
        """
        为决策树算法提供参数
        :return: 参数字典
        """
        parser = argparse.ArgumentParser(
            description='为决策树算法提供参数'  # 程序目的说明
        )
        parser.add_argument('--criterion'
                            , required=False
                            , default='entropy'
                            , choices=['gini', 'entropy']
                            , type=str
                            , help='决策树节点分裂准则选择')
        parser.add_argument('--splitter'
                            , required=False
                            , default='best'
                            , choices=['best', 'random']
                            , type=str
                            , help='决策树随机/最优选择特征划分')
        parser.add_argument('--max_depth'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树最大深度')
        parser.add_argument('--min_samples_split'
                            , required=False
                            , default=2
                            , type=int
                            , help='决策树父节点最少划分样本量')
        parser.add_argument('--min_samples_leaf'
                            , required=False
                            , default=2
                            , type=int
                            , help='决策树叶子节点最少划分样本量')
        parser.add_argument('--min_weight_fraction_leaf'
                            , required=False
                            , default=0
                            , type=int
                            , help='决策树忘了这是啥')
        parser.add_argument('--max_leaf_nodes'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树最大叶子节点数')
        parser.add_argument('--random_state'
                            , required=False
                            , default=10
                            , type=int
                            , help='随机模式参数')
        parser.add_argument('--max_features'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树训练用最大特征数量')
        parser = parser.parse_args()
        print(vars(parser))
        return vars(parser)
    
  • 查看help

    $ python 10、Package_argparse.py --help
    usage: 10、Package_argparse.py [-h] [--criterion {gini,entropy}]
                                  [--splitter {best,random}]
                                  [--max_depth MAX_DEPTH]
                                  [--min_samples_split MIN_SAMPLES_SPLIT]
                                  [--min_samples_leaf MIN_SAMPLES_LEAF]
                                  [--min_weight_fraction_leaf MIN_WEIGHT_FRACTION_LEAF]
                                  [--max_leaf_nodes MAX_LEAF_NODES]
                                  [--random_state RANDOM_STATE]
                                  [--max_features MAX_FEATURES]
    
    为决策树算法提供参数
    
    optional arguments:
      -h, --help            show this help message and exit
      --criterion {gini,entropy}
                            决策树节点分裂准则选择
      --splitter {best,random}
                            决策树随机/最优选择特征划分
      --max_depth MAX_DEPTH
                            决策树最大深度
      --min_samples_split MIN_SAMPLES_SPLIT
                            决策树父节点最少划分样本量
      --min_samples_leaf MIN_SAMPLES_LEAF
                            决策树叶子节点最少划分样本量
      --min_weight_fraction_leaf MIN_WEIGHT_FRACTION_LEAF
                            决策树忘了这是啥
      --max_leaf_nodes MAX_LEAF_NODES
                            决策树最大叶子节点数
      --random_state RANDOM_STATE
                            随机模式参数
      --max_features MAX_FEATURES
                            决策树训练用最大特征数量
    
    
  • 全部代码

    # -*- coding: utf-8 -*-
    
    """
    **************************************************
    @author:   Ying                                      
    @software: PyCharm                       
    @file: 10、Package_argparse.py
    @time: 2021-07-06 09:23                          
    **************************************************
    """
    import argparse
    from sklearn.datasets import load_wine, load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.metrics import classification_report
    from sklearn.metrics import plot_roc_curve, plot_precision_recall_curve, plot_confusion_matrix
    from sklearn.tree import DecisionTreeClassifier
    import pandas as pd
    from matplotlib import pyplot as plt
    import os
    import sys
    
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams["font.family"] = 'Arial Unicode MS'
    
    global cur_file_path
    cur_file_path = os.path.split(sys.argv[0])[0]
    
    
    def get_parser():
        """
        为决策树算法提供参数
        :return: 参数字典
        """
        parser = argparse.ArgumentParser(
            description='为决策树算法提供参数'  # 程序目的说明
        )
        parser.add_argument('--criterion'
                            , required=False
                            , default='entropy'
                            , choices=['gini', 'entropy']
                            , type=str
                            , help='决策树节点分裂准则选择')
        parser.add_argument('--splitter'
                            , required=False
                            , default='best'
                            , choices=['best', 'random']
                            , type=str
                            , help='决策树随机/最优选择特征划分')
        parser.add_argument('--max_depth'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树最大深度')
        parser.add_argument('--min_samples_split'
                            , required=False
                            , default=2
                            , type=int
                            , help='决策树父节点最少划分样本量')
        parser.add_argument('--min_samples_leaf'
                            , required=False
                            , default=2
                            , type=int
                            , help='决策树叶子节点最少划分样本量')
        parser.add_argument('--min_weight_fraction_leaf'
                            , required=False
                            , default=0
                            , type=int
                            , help='决策树忘了这是啥')
        parser.add_argument('--max_leaf_nodes'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树最大叶子节点数')
        parser.add_argument('--random_state'
                            , required=False
                            , default=10
                            , type=int
                            , help='随机模式参数')
        parser.add_argument('--max_features'
                            , required=False
                            , default=None
                            , type=int
                            , help='决策树训练用最大特征数量')
        parser = parser.parse_args()
        print(vars(parser))
        return vars(parser)
    
    
    class DT:
        def __init__(self, criterion, splitter, max_depth
                     , min_samples_split, min_samples_leaf
                     , min_weight_fraction_leaf, max_leaf_nodes
                     , random_state, max_features, data=None
                     , X=None
                     , y=None):
            """
            初始化函数
            :param criterion: 决策树节点分裂准则选择
            :param splitter: 决策树随机/最优选择特征划分
            :param max_depth: 决策树最大深度
            :param min_samples_split: 决策树父节点最少划分样本量
            :param min_samples_leaf: 决策树叶子节点最少划分样本量
            :param min_weight_fraction_leaf: 决策树忘了这是啥
            :param max_leaf_nodes: 决策树最大叶子节点数
            :param random_state: 随机模式参数
            :param max_features: 决策树训练用最大特征数量
            :param data: 数据集
            :param X: X
            :param y: y
            """
            self.criterion = parameters[criterion]
            self.splitter = parameters[splitter]
            self.max_depth = parameters[max_depth]
            self.min_samples_split = parameters[min_samples_split]
            self.min_samples_leaf = parameters[min_samples_leaf]
            self.min_weight_fraction_leaf = parameters[min_weight_fraction_leaf]
            self.max_leaf_nodes = parameters[max_leaf_nodes]
            self.random_state = parameters[random_state]
            self.max_features = parameters[max_features]
    
            self.data = data  # 传进来的data包含标签列(y)
            self.feature_names = self.data.columns[:-1] if self.data is not None else None
            self.X = X
            self.y = y
    
        def create_datas(self):
            '''创建数据集'''
            wine = load_wine()
            self.X = wine.data
            self.y = wine.target
            self.feature_names = wine.feature_names
    
            self.data = pd.DataFrame(self.X, columns=self.feature_names)
            self.data['target'] = pd.DataFrame(self.y)
    
        def data_split(self):
            """训练集、测试集划分"""
            return train_test_split(self.X
                                    , self.y
                                    , test_size=0.3
                                    , random_state=self.random_state)
    
        def Modeling(self, X_train, X_test, y_train, y_test):
            """
            训练决策树
            """
            clf = DecisionTreeClassifier(criterion=self.criterion
                                         , splitter=self.splitter
                                         , max_depth=self.max_depth
                                         , min_samples_split=self.min_samples_split
                                         , min_samples_leaf=self.min_samples_leaf
                                         , min_weight_fraction_leaf=self.min_weight_fraction_leaf
                                         , max_leaf_nodes=self.max_leaf_nodes
                                         , random_state=self.random_state
                                         , max_features=self.max_features
                                         )
            clf.fit(X_train, y_train)
            score = clf.score(X_test, y_test)  # accuracy
            print(f'decision tree accuracy:{score}')
    
            y_test_predict = clf.predict(X_test)  # 测试集预测结果
            report_of_classifier = classification_report(y_test, y_test_predict)
            print(report_of_classifier)
    
            try:
                # roc曲线 and p_r曲线 and 混淆矩阵
    
                figure = plt.figure(num=1, figsize=(12, 8), dpi=300)
                ax1 = plt.subplot(212)
                plot_confusion_matrix(clf, X_test, y_test, ax=ax1)
                ax1.set_title('混淆矩阵')
                ax1.margins(2, 2)
    
                ax2 = plt.subplot(221)
                plot_roc_curve(clf, X_test, y_test, ax=ax2)
                ax2.set_title('ROC 曲线')
    
                ax3 = plt.subplot(222)
                plot_precision_recall_curve(clf, X_test, y_test, ax=ax3)
                ax3.set_title('P_R 曲线')
    
                # 当前程序文件名称(去除.py)
                cur_file_name = os.path.split(sys.argv[0])[1].split('.')[0]
                if not os.path.exists('save'):
                    os.mkdir('save')
    
                # 图片存储路径
                photo_save_path = os.path.join(cur_file_path, 'save', cur_file_name + '.png')
    
                figure.savefig(photo_save_path)
                plt.show()
    
            except ValueError:
                print('DecisionTreeClassifier should be a binary classifier')
    
        def run(self):
            if self.data is None:
                self.create_datas()
            else:
                self.X = self.data.iloc[:, :-1]
                self.y = self.data.iloc[:, -1]
    
            X_train, X_test, y_train, y_test = self.data_split()
            self.Modeling(X_train, X_test, y_train, y_test)
    
    
    if __name__ == '__main__':
        # 创建乳腺癌数据集
        breast = load_breast_cancer()
        data = pd.DataFrame(breast.data, columns=breast.feature_names)
        data['target'] = pd.DataFrame(breast.target)
    
        parameters = get_parser()
        dt = DT(*parameters, data)
        # dt = DT(*parameters)
    
        dt.run()
    
    
  • 可以直接运行,因为参数全都非必须,且都有默认值

    $ python 10、Package_argparse.py 
    {'criterion': 'entropy', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 2, 'min_weight_fraction_leaf': 0, 'max_leaf_nodes': None, 'random_state': 10, 'max_features': None}
    decision tree accuracy:0.9532163742690059
                  precision    recall  f1-score   support
    
               0       0.89      0.98      0.94        59
               1       0.99      0.94      0.96       112
    
        accuracy                           0.95       171
       macro avg       0.94      0.96      0.95       171
    weighted avg       0.96      0.95      0.95       171
    
    

    也可以命令行指定参数运行

    $ python 10、Package_argparse.py --criterion gini
    {'criterion': 'gini', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 2, 'min_weight_fraction_leaf': 0, 'max_leaf_nodes': None, 'random_state': 10, 'max_features': None}
    decision tree accuracy:0.9298245614035088
                  precision    recall  f1-score   support
    
               0       0.87      0.93      0.90        59
               1       0.96      0.93      0.95       112
    
        accuracy                           0.93       171
       macro avg       0.92      0.93      0.92       171
    weighted avg       0.93      0.93      0.93       171
    
    

通过配置文件(以随机森林为例,两个参数做测试)

  • 主体:

    配置文件

    n_estimators:10
    criterion:gini
    

    主体代码

    将配置文件中的参数读取,设置为默认(defualt),所以也可以手动改参数

    def read_conf(cls, conf_name):
        """"""
        conf_items = {}
        with open(conf_name) as f:
            line = f.readline()  # 单行读取
            while line:
                # 忽略空行和注视
                if len(line.strip()) > 3 and (not line.strip().startswith('#')):
                    kv_list = line.split(':')
                    key = kv_list[0].strip()
                    value = kv_list[1].strip()
                    conf_items[key] = value
                line = f.readline()
        return conf_items
    
    
    def get_parser():
        """
        usage : 获取参数信息
    
        process_path : 当前程序所在目录
        conf_info_path : 参数配置文件地址
        conf_dict : 处理完成返回的参数字典
    
        :return: 参数字典
        """
    
        conf_info_path = os.path.join(process_path, 'need/randomforecast.conf')
    
        conf_dict = read_conf(conf_info_path)
    
        parse = argparse.ArgumentParser(description='为随机隋林算法提供参数')
        parse.add_argument('--n_estimators'
                           , required=False
                           , type=int
                           , default=conf_dict['n_estimators']
                           , help='随机森林 基分类器数量')
        parse.add_argument('--criterion'
                           , required=False
                           , type=str
                           , default=conf_dict['criterion']
                           , choices=['entropy', 'gini']
                           , help='随机森林 基分类器节点分裂准则')
    
        parse = parse.parse_args()
    
        parameters = vars(parse)
        print(parameters)
        return parameters
    
    
  • 全部代码

    # -*- coding: utf-8 -*-
    
    """
    **************************************************
    @author:   Ying                                      
    @software: PyCharm                       
    @file: GeneralTools.py
    @time: 2021-07-07 09:52                          
    **************************************************
    """
    
    
    class GeneralTools:
        @classmethod
        def read_conf(cls, conf_name):
            """"""
            conf_items = {}
            with open(conf_name) as f:
                line = f.readline()  # 单行读取
                while line:
                    # 忽略空行和注视
                    if len(line.strip()) > 3 and (not line.strip().startswith('#')):
                        kv_list = line.split(':')
                        key = kv_list[0].strip()
                        value = kv_list[1].strip()
                        conf_items[key] = value
                    line = f.readline()
            return conf_items
    
    
    # -*- coding: utf-8 -*-
    
    """
    **************************************************
    @author:   Ying                                      
    @software: PyCharm                       
    @file: 2、RandomForecastClassifier.py
    @time: 2021-07-07 09:50                          
    **************************************************
    """
    
    import os
    import sys
    import argparse
    import pandas as pd
    from GeneralTools import GeneralTools
    from sklearn.datasets import load_breast_cancer
    from sklearn.model_selection import train_test_split
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import LabelEncoder
    from sklearn.metrics import classification_report
    from matplotlib import pyplot as plt
    from sklearn.metrics import plot_confusion_matrix, plot_precision_recall_curve, plot_roc_curve, plot_det_curve
    
    global process_path
    process_path = os.path.split(sys.argv[0])[0]
    
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams["font.family"] = 'Arial Unicode MS'
    
    
    def get_parser():
        """
        usage : 获取参数信息
    
        process_path : 当前程序所在目录
        conf_info_path : 参数配置文件地址
        conf_dict : 处理完成返回的参数字典
    
        :return: 参数字典
        """
    
        conf_info_path = os.path.join(process_path, 'need/randomforecast.conf')
    
        conf_dict = GeneralTools().read_conf(conf_info_path)
    
        parse = argparse.ArgumentParser(description='为随机隋林算法提供参数')
        parse.add_argument('--n_estimators'
                           , required=False
                           , type=int
                           , default=conf_dict['n_estimators']
                           , help='随机森林 基分类器数量')
        parse.add_argument('--criterion'
                           , required=False
                           , type=str
                           , default=conf_dict['criterion']
                           , choices=['entropy', 'gini']
                           , help='随机森林 基分类器节点分裂准则')
    
        parse = parse.parse_args()
    
        parameters = vars(parse)
        print(parameters)
        return parameters
    
    
    class RF:
        def __init__(self, n_estimators, criterion, data=None):
            self.criterion = parameters[criterion]
            self.n_estimators = parameters[n_estimators]
    
            self.data = data
            self.feature_names = self.data.columns if self.data is not None else None
            self.X = None
            self.y = None
    
        def transform_target(self, target):
            """仅对疫情target字段进行处理"""
            if target >= -0.2096:
                return '免疫力等级高'
            return '免疫力等级低'
    
        def data_preprocessing(self):
            """仅对疫情数据集进行处理"""
            self.data.drop(['企业名称', '分类', '企业性质_2'], axis=1, inplace=True)
            self.data.columns = ['industry', 'iso_num', 'qiye_nature', 'LC_performance',
                                 'shareholder_num', 'shareholder_com', 'zl_zscq', 'ywfb',
                                 'clsj', 'zczj', 'sjzj', 'person_num', 'difference']
            self.data['target'] = data['difference'].apply(self.transform_target)
            self.data.drop('difference', axis=1)
    
            le = LabelEncoder()
            data['target'] = le.fit_transform(data['target'])
            modeling_data = data.join(pd.get_dummies(data['shareholder_com'])).join(pd.get_dummies(data['industry'])).join(
                pd.get_dummies(data['qiye_nature'])).drop(['industry', 'qiye_nature', 'shareholder_com']
                                                          , axis=1)
            target_data = modeling_data['target']
            modeling_data.drop('target', axis=1, inplace=True)
            modeling_data = modeling_data.join(target_data)
    
            self.data = modeling_data
            self.X = self.data.iloc[:, :-1]
            self.y = self.data.iloc[:, -1]
    
        def split_data(self):
            return train_test_split(self.X
                                    , self.y
                                    , random_state=120
                                    , test_size=0.3)
    
        def create_data(self):
            breast = load_breast_cancer()
            self.feature_names = breast.feature_names
            self.data = pd.DataFrame(breast.data, columns=self.feature_names)
            self.data['target'] = pd.Series(breast.target)
    
            self.X = breast.data
            self.y = breast.target
    
        def modelling(self, X_train, X_test, y_train, y_test):
            rf_clf = RandomForestClassifier(n_estimators=self.n_estimators
                                            , criterion=self.criterion)
            rf_clf.fit(X_train, y_train)
            acc = rf_clf.score(X_test, y_test)
    
            print(f"accuracy:{acc}")
    
            y_test_pred = rf_clf.predict(X_test)
    
            report_result = classification_report(y_test, y_test_pred)
            print(report_result)
    
            try:
                fig, axes = plt.subplots(2, 2)
                fig.set_size_inches(20, 12)
                fig.set_dpi(300)
    
                plot_roc_curve(rf_clf, X_test, y_test, ax=axes[0, 0])
                axes[0, 0].set_title('roc 曲线')
                plot_precision_recall_curve(rf_clf, X_test, y_test, ax=axes[0, 1])
                axes[0, 1].set_title('P_R 曲线')
                plot_confusion_matrix(rf_clf, X_test, y_test, ax=axes[1, 0])
                axes[1, 0].set_title('混淆矩阵')
                plot_det_curve(rf_clf, X_test, y_test, ax=axes[1, 1])
                axes[1, 1].set_title('det 曲线')
    
                plt.show()
            except:
                pass
    
        def run(self):
            if self.data is None:
                self.create_data()
            else:
                self.data_preprocessing()
    
            X_train, X_test, y_train, y_test = self.split_data()
    
            self.modelling(X_train, X_test, y_train, y_test)
    
    
    if __name__ == '__main__':
        parameters = get_parser()
    
        data = pd.read_excel(os.path.join(process_path, 'need/疫情数据集.xlsx'))
    
        rf = RF(*parameters, data)
        rf.run()
    

9. 封装

  • 通用类

    增加参数在 initialize函数中增加

    import argparse
    import logging
    
    class BaseOptions:
        """BaseOptions:   定义参数使用 
        
        Examples
        -------------------------------------------------------------------------------------------------------------------
        # 获取所有参数
        >>> from options import BaseOptions
        >>> args = BaseOptions().parse()
        >>> # 取参数
        >>> args.参数名
        >>> # 新增参数
        >>> args.新参数名=参数值
        >>> # 修改参数
        >>> args.已有参数=修改的参数值
        >>> 
        """
        def __init__(self):
            """__init__:    初始化类,该类尚未初始化
            """
            self.initialized = False
    
        def initialize(self,parser):
            """initialize:   定义通用的参数 
            Parameters:
            -------------------------------------------------------------------------------------------------------------------
            >>> parser.add_argument('--dataroot'            # 变量名
                                  , type=str                # 读取参数的类型
                                  , required=False          # 可选项是否必须有
                                  , default='a'             # 默认值,类型与type统一
                                  , choices=['a','b']       # 参数值只能从几个选项里面选择
                                  , help='11111'            # 该参数说明
                                  )
            -------------------------------------------------------------------------------------------------------------------
            """
            parser.add_argument('--dataroot'       # 变量名
                              , type=str                # 读取参数的类型
                              , required=False          # 可选项是否必须有
                              , default='a'             # 默认值,类型与type统一
                              , choices=['a','b']       # 参数值只能从几个选项里面选择
                              , help='11111'            # 该参数说明
                             )
            self.initialized = True
            return parser
    
        def gather_options(self):
            """gather_options:   初始化通用的参数(只初始化一次) 
            """
            if not self.initialized:
                parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
                parser = self.initialize(parser)
            return parser.parse_args()      
    
        def parse(self):
            """parse:   解析参数 
            """
            opt = self.gather_options()
            self.opt = opt
            return self.opt
        
    

  • 9
    点赞
  • 16
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值