文章目录
1. argparse模块简介
-
argparse是Python中的一个常用模块,和sys.argv()功能类似,主要用于编写命令行接口:对于程序所需要的参数,它可以进行正确的解析。另外,argparse还可以自动的生成help和 usage信息,当程序的参数无效时,它可以自动生成错误信息。
-
https://docs.python.org/3/library/argparse.html
2. argparse例子
-
下列代码是python程序,它可以接受多个整数,并返回它们的和或者最大值。
import argparse parser = argparse.ArgumentParser(description='处理一些整数') parser.add_argument('integers', metavar='N', type=int, nargs='+', help='累加器中的整数') parser.add_argument('--sum', dest='accumulate', action='store_const', const=sum, default=max, help='对整数进行求和(默认: 寻找最大值)') args = parser.parse_args() print(args.accumulate(args.integers))
假定上述代码保存为test.py,它可以在命令行执行并自动提供有用的
help
信息:(base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py -h usage: test.py [-h] [--sum] N [N ...] 处理一些整数 positional arguments: N 累加器中的整数 optional arguments: -h, --help show this help message and exit --sum 对整数进行求和(默认: 寻找最大值)
该程序可以接受相应的参数并给出相应的输出:
(base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py 1 2 3 4 4 (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py 1 2 3 4 --sum 10
如果传入无效的参数,它可以自动生成`error`信息: ```bash (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python test.py a b usage: test.py [-h] [--sum] N [N ...] test.py: error: argument N: invalid int value: 'a'
3. argparse 的三个主要函数
Creating a parser:
-
parser=argparse.ArgumentParser()
使用argparse的第一步是创建ArgumentParser对象:
parser = argparse.ArgumentParser(description='处理一些整数')
ArgumentParser对象保存了所有必要的信息,用以将命令行参数解析为相应的
python
数据类型。
Adding arguments:
-
parser.add_argument()
调用add_argument()向ArgumentParser对象添加命令行参数信息,这些信息告诉ArgumentParser对象如何处理命令行参数。可以通过调用parse_agrs()来使用这些命令行参数。例如:
parser.add_argument('integers', metavar='N', type=int, nargs='+', help='累加器中的整数') parser.add_argument('--sum', dest='accumulate', action='store_const', const=sum, default=max, help='对整数进行求和(默认: 寻找最大值)')
之后,调用parse.args()将返回一个对象,它具有两个属性:
integers
和accumulate
,前者是一个整数集合,后者根据命令行参数的不同,代表不同的函数,可以是max(),也可以是sum()。
Parsing arguments:
-
args = parser.parse_args()
通过调用parse_args()来解析ArgumentParser对象中保存的命令行参数:将命令行参数解析成相应的数据类型并采取相应的动作,它返回一个
Namespace
对象。>>> parser.parse_args(['--sum', '7', '-1', '42']) Namespace(accumulate=<built-in function sum>, integers=[7, -1, 42])
在实际的
python
脚本中,parse_args()一般并不使用参数,它的参数由sys.argv
决定。
4. ArgumentParser对象(ArgumentParser objects)
class argparse.ArgumentParser(prog=None
, usage=None
, description=None
, epilog=None
, parents=[]
, formatter_class=argparse.HelpFormatter
, argument_default=None
, conflict_handler=’error’
, add_help=True
, allow_abbrev=True
)
对象参数简介
- 创建一个新的 ArgumentParser 对象。所有参数都应作为关键字参数传递。每个参数在下面都有更详细的描述,但简而言之,它们是:
- prog - The name of the program (default:
sys.argv[0]
)- 程序文件名
- usage - The string describing the program usage (default: generated from arguments added to parser)
- 程序使用说明
- description - Text to display before the argument help (default: none)
- 程序目的说明
- epilog - Text to display after the argument help (default: none)
- 程序说明后记
- parents - A list of
ArgumentParser
objects whose arguments should also be included- ArgumentParser对象的父对象的参数列表。
- formatter_class - A class for customizing the help output
help
信息的说明格式
- prefix_chars - The set of characters that prefix optional arguments (default: ‘-‘)
- 命令行参数前缀。
- fromfile_prefix_chars - The set of characters that prefix files from which additional arguments should be read (default:
None
) - argument_default - The global default value for arguments (default:
None
)- 参数全局默认值
- conflict_handler - The strategy for resolving conflicting optionals (usually unnecessary)
- 冲突处理
- add_help - Add a
-h/--help
option to the parser (default:True
)- 是否增加help选项
- allow_abbrev - Allows long options to be abbreviated if the abbreviation is unambiguous. (default:
True
)- 是否使用参数缩写
- exit_on_error - Determines whether or not ArgumentParser exits with error info when an error occurs. (default:
True
)
- prog - The name of the program (default:
常用对象参数具体用法
-
prog
(程序文件名):默认情况下,ArgumentParser对象使用sys.argv[0]来决定如何在
help
信息中显示程序的文件名。
例如:下列代码的文件名为myprogram.pyimport argparse parser = argparse.ArgumentParser() parser.add_argument('--foo', help='foo help') args = parser.parse_args()
它的
help
信息将显示程序的文件名,而不管程序在何处进行调用。(base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ python myprogram.py --help usage: myprogram.py [-h] [--foo FOO] optional arguments: -h, --help show this help message and exit --foo FOO foo help (base) mengzhichaodeMacBook-Pro:Sample_Code mengzhichao$ cd .. (base) mengzhichaodeMacBook-Pro:16、python代码 mengzhichao$ python Sample_Code/myprogram.py --help usage: myprogram.py [-h] [--foo FOO] optional arguments: -h, --help show this help message and exit --foo FOO foo help
使用prog参数可以修改ArgumentParser对象的默认文件名:
$parser = argparse.ArgumentParser(prog='myprogram') $parser.print_help() usage: myprogram [-h] optional arguments: -h, --help show this help message and exit
不论程序的文件名是来自sys.argv[0],不是来自参数prog,都可以通过%(prog)s来引用程序的文件名。
$ parser = argparse.ArgumentParser(prog='myprogram') $ parser.add_argument('--foo', help='foo of the %(prog)s program') $ parser.print_help() usage: myprogram [-h] [--foo FOO] optional arguments: -h, --help show this help message and exit --foo FOO foo of the myprogram program
-
usage
程序使用说明默认情况下,ArgumentParser对象会根据它所包括的参数来自动的生成使用说明。
$ parser = argparse.ArgumentParser(prog='PROG') $ parser.add_argument('--foo', nargs='?', help='foo help') $ parser.add_argument('bar', nargs='+', help='bar help') $ parser.print_help() usage: PROG [-h] [--foo [FOO]] bar [bar ...] positional arguments: bar bar help optional arguments: -h, --help show this help message and exit --foo [FOO] foo help
可以使用参数usage来修改程序的使用说明信息。
$ parser = argparse.ArgumentParser(prog='PROG', usage='%(prog)s [options]') $ parser.add_argument('--foo', nargs='?', help='foo help') $ parser.add_argument('bar', nargs='+', help='bar help') $ parser.print_help() usage: PROG [options] positional arguments: bar bar help optional arguments: -h, --help show this help message and exit --foo [FOO] foo help
-
description
程序目的说明一般情况下,ArgumentParser对象会使用参数description来说明程序的用途以及目的。
description
信息一般位于usage
信息和help
信息之间$ parser = argparse.ArgumentParser(description='A foo that bars') $ parser.print_help() usage: argparse.py [-h] A foo that bars optional arguments: -h, --help show this help message and exit
默认情况下,
description
会自动调整间距来适应显示空间。可以通过formatter_class
类来修改它的显示方式。 -
epilog 程序说明后记
一些程序喜欢在参数描述信息之后添加额外的信息,这些信息可以通过参数epilog来指定。
$ parser = argparse.ArgumentParser( ... description='A foo that bars', ... epilog="And that's how you'd foo a bar") $ parser.print_help() usage: argparse.py [-h] A foo that bars optional arguments: -h, --help show this help message and exit And that's how you'd foo a bar
和
description
类似,epilog
会自动调整间距来适应显示空间。可以通过formatter_class
类来修改它的显示方式。
5. add_argument()方法
函数参数简介
- 函数参数
- name or flags - Either a name or a list of option strings, e.g.
foo
or-f, --foo
.- 参数名或者参数标识
- 带
-
的为可选参数(optional parameter) - 不带
-
的为必选参数(positional parametrer)。
- action - The basic type of action to be taken when this argument is encountered at the command line.
- 参数的处理方法
- nargs - The number of command-line arguments that should be consumed.
- 参数的数量
- const - A constant value required by some action and nargs selections.
- 参数的常量值
- default - The value produced if the argument is absent from the command line and if it is absent from the namespace object.
- 参数默认值
- type - The type to which the command-line argument should be converted.
- 参数的数据类型
- choices - A container of the allowable values for the argument.
- 参数取值范围
- required - Whether or not the command-line option may be omitted (optionals only).
- 参数是否可以忽略不写 ,仅对可选参数有效
- help - A brief description of what the argument does.
- 参数的说明信息
- metavar - A name for the argument in usage messages.
- 参数在说明信息
usage
中的名称
- 参数在说明信息
- dest - The name of the attribute to be added to the object returned by
parse_args()
.- 对象的属性名
- name or flags - Either a name or a list of option strings, e.g.
常用函数参数具体用法
-
default
:默认参数值"""程序名称:myprogram.py""" import argparse def get_parser(): parser = argparse.ArgumentParser(description="Demo of argparse") parser.add_argument('--name' , default='Great' # 默认值为Great ) return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() name = args.name print('Hello {}'.format(name))
终端运行结果
$ python myprogram.py Hello Great $ python myprogram.py --name mengzhichao Hello mengzhichao
-
required
:表示这个参数是否一定要设置如果设置了
required=True
,则在实际运行的时候不设置该参数将报错(即使在你已经指定了default的情况下):$ python myprogram.py usage: myprogram.py [-h] --name NAME myprogram.py: error: the following arguments are required: --name $ python myprogram.py --name mengzhichao Hello mengzhichao
-
type
:参数类型默认的参数类型是str类型,如果你的程序需要一个整数或者布尔型参数,你需要设置
type=int
或type=bool
,下面是一个打印平方的例子:"""程序名称:myprogram.py""" import argparse def get_parser(): parser = argparse.ArgumentParser( description='Calculate square of a given number') parser.add_argument('-number' , type=int # 参数类型 , required=True # 该参数是否一定需要 , default=10 # 参数默认值 ) return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() res = args.number ** 2 print('square of {} is {}'.format(args.number, res))
如果不是指定的参数类型,那么会报错
$ python myprogram.py -number 10 square of 10 is 100 $ python myprogram.py -number 'str' usage: myprogram.py [-h] -number NUMBER myprogram.py: error: argument -number: invalid int value: 'str'
-
choices
: 参数值只能从几个选项里面选择import argparse def get_parser(): parser = argparse.ArgumentParser( description='choices demo') parser.add_argument('-arch' , required=True , choices=['alexnet', 'vgg'] , type=str) return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() print('the arch of CNN is {}'.format(args.arch))
如果向下面第一条这样执行会报错,因为第一条所给的
-arch
参数resnet
不在备选的choices
之中,所以会报错:$ python myprogram.py -arch resnet usage: myprogram.py [-h] -arch {alexnet,vgg} myprogram.py: error: argument -arch: invalid choice: 'resnet' (choose from 'alexnet', 'vgg') $ python myprogram.py -arch alexnet the arch of CNN is alexnet
-
help
:指定参数的说明信息import argparse def get_parser(): parser = argparse.ArgumentParser( description='help demo') parser.add_argument('-arch' , required=True , choices=['alexnet', 'vgg'] , help='the architecture of CNN, \ at this time we only support alexnet and vgg.') return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() print('the arch of CNN is {}'.format(args.arch))
在命令行加
-h
或--help
参数运行该命令,获取帮助信息的时候,结果如下:$ python myprogram.py --help usage: myprogram.py [-h] -arch {alexnet,vgg} help demo optional arguments: -h, --help show this help message and exit -arch {alexnet,vgg} the architecture of CNN, at this time we only support alexnet and vgg.
-
dest
:设置参数在代码中的别名—被添加到parse_args()
所返回对象上的属性名:注意:为该参数设置别名后,只能用别名在namespace中调用。(参考下述代码8 22 行)
import argparse def get_parser(): parser = argparse.ArgumentParser( description='help demo') parser.add_argument('-arch' , dest='ar' , required=False , default='alexnet' , choices=['alexnet', 'vgg'] , help='the architecture of CNN, \ at this time we only support alexnet and vgg.') return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() print('the arch of CNN is {}'.format(args.ar))
-
nargs
:设置参数在使用时可以提供的个数使用方式如下
parser.add_argument(‘-name’, nargs=x)
其中
x
的候选值和含义如下:值 含义 N 参数的绝对个数(如:3) ? 0或者1个参数 * 0或所有参数 + 所有,并且至少一个参数 如下例子:
import argparse def get_parser(): parser = argparse.ArgumentParser( description='nargs demo') parser.add_argument('-name' , type=str # 类型 , dest='nas' # 别名 , metavar='names' # 在usage中显示的名称 , required=False # 是否一定需要 , default='A' # 默认 , choices=['A', 'B', 'C'] # 参数选择范围 , nargs='+' , ) return parser if __name__ == '__main__': parser = get_parser() args = parser.parse_args() names = ', '.join(args.nas) print('Hello to {}'.format(names))
执行命令和结果如下:
$ python myprogram.py -name A B C Hello to A, B, C $ python myprogram.py -name A g usage: myprogram.py [-h] [-name names [names ...]] myprogram.py: error: argument -name: invalid choice: 'g' (choose from 'A', 'B', 'C')
6. parse_args()方法
-
ArgumentParser.parse_args(args=None, namespace=None)
parse_args()
方法将命令行参数字符串转换为相应对象并赋值给Namespace
对象的相应属性,默认返回一个Namespace
对象。args
: 字符串列表,默认来自sys.argv。namespace
:对象名,默认是一个空的namespace对象。
7. namespace 对象
-
class argparse.Namespace
调用parse_args()的返回值是一个
Namespace
对象,它具有很多属性,每个属性都代表相应的命令行参数。Namespace
对象是一个非常简单的类,可以通过vars()将之转换成字典类型。例如:>>> import argparse >>> parser = argparse.ArgumentParser() >>> parser.add_argument('--test') _StoreAction(option_strings=['--test'], dest='test', nargs=None, const=None, default=None, type=None, choices=None, help=None, metavar=None) >>> args = parser.parse_args(['--test', 'TEST']) >>> args Namespace(test='TEST') >>> vars(args) {'test': 'TEST'}
-
另外还可以将ArgumentParser对象赋值给别的命令空间,而不是新建一个
Namespace
对象,例如:$ class C: ... pass ... $ c = C() $ parser = argparse.ArgumentParser() $ parser.add_argument('--foo') $ parser.parse_args(args=['--foo', 'BAR'], namespace=c) $ c.foo 'BAR'
8.argparse 实例(以决策树参数、随机森林参数为例)
直接添加
-
主体
def get_parser(): """ 为决策树算法提供参数 :return: 参数字典 """ parser = argparse.ArgumentParser( description='为决策树算法提供参数' # 程序目的说明 ) parser.add_argument('--criterion' , required=False , default='entropy' , choices=['gini', 'entropy'] , type=str , help='决策树节点分裂准则选择') parser.add_argument('--splitter' , required=False , default='best' , choices=['best', 'random'] , type=str , help='决策树随机/最优选择特征划分') parser.add_argument('--max_depth' , required=False , default=None , type=int , help='决策树最大深度') parser.add_argument('--min_samples_split' , required=False , default=2 , type=int , help='决策树父节点最少划分样本量') parser.add_argument('--min_samples_leaf' , required=False , default=2 , type=int , help='决策树叶子节点最少划分样本量') parser.add_argument('--min_weight_fraction_leaf' , required=False , default=0 , type=int , help='决策树忘了这是啥') parser.add_argument('--max_leaf_nodes' , required=False , default=None , type=int , help='决策树最大叶子节点数') parser.add_argument('--random_state' , required=False , default=10 , type=int , help='随机模式参数') parser.add_argument('--max_features' , required=False , default=None , type=int , help='决策树训练用最大特征数量') parser = parser.parse_args() print(vars(parser)) return vars(parser)
-
查看help
$ python 10、Package_argparse.py --help usage: 10、Package_argparse.py [-h] [--criterion {gini,entropy}] [--splitter {best,random}] [--max_depth MAX_DEPTH] [--min_samples_split MIN_SAMPLES_SPLIT] [--min_samples_leaf MIN_SAMPLES_LEAF] [--min_weight_fraction_leaf MIN_WEIGHT_FRACTION_LEAF] [--max_leaf_nodes MAX_LEAF_NODES] [--random_state RANDOM_STATE] [--max_features MAX_FEATURES] 为决策树算法提供参数 optional arguments: -h, --help show this help message and exit --criterion {gini,entropy} 决策树节点分裂准则选择 --splitter {best,random} 决策树随机/最优选择特征划分 --max_depth MAX_DEPTH 决策树最大深度 --min_samples_split MIN_SAMPLES_SPLIT 决策树父节点最少划分样本量 --min_samples_leaf MIN_SAMPLES_LEAF 决策树叶子节点最少划分样本量 --min_weight_fraction_leaf MIN_WEIGHT_FRACTION_LEAF 决策树忘了这是啥 --max_leaf_nodes MAX_LEAF_NODES 决策树最大叶子节点数 --random_state RANDOM_STATE 随机模式参数 --max_features MAX_FEATURES 决策树训练用最大特征数量
-
全部代码
# -*- coding: utf-8 -*- """ ************************************************** @author: Ying @software: PyCharm @file: 10、Package_argparse.py @time: 2021-07-06 09:23 ************************************************** """ import argparse from sklearn.datasets import load_wine, load_breast_cancer from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report from sklearn.metrics import plot_roc_curve, plot_precision_recall_curve, plot_confusion_matrix from sklearn.tree import DecisionTreeClassifier import pandas as pd from matplotlib import pyplot as plt import os import sys plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False plt.rcParams["font.family"] = 'Arial Unicode MS' global cur_file_path cur_file_path = os.path.split(sys.argv[0])[0] def get_parser(): """ 为决策树算法提供参数 :return: 参数字典 """ parser = argparse.ArgumentParser( description='为决策树算法提供参数' # 程序目的说明 ) parser.add_argument('--criterion' , required=False , default='entropy' , choices=['gini', 'entropy'] , type=str , help='决策树节点分裂准则选择') parser.add_argument('--splitter' , required=False , default='best' , choices=['best', 'random'] , type=str , help='决策树随机/最优选择特征划分') parser.add_argument('--max_depth' , required=False , default=None , type=int , help='决策树最大深度') parser.add_argument('--min_samples_split' , required=False , default=2 , type=int , help='决策树父节点最少划分样本量') parser.add_argument('--min_samples_leaf' , required=False , default=2 , type=int , help='决策树叶子节点最少划分样本量') parser.add_argument('--min_weight_fraction_leaf' , required=False , default=0 , type=int , help='决策树忘了这是啥') parser.add_argument('--max_leaf_nodes' , required=False , default=None , type=int , help='决策树最大叶子节点数') parser.add_argument('--random_state' , required=False , default=10 , type=int , help='随机模式参数') parser.add_argument('--max_features' , required=False , default=None , type=int , help='决策树训练用最大特征数量') parser = parser.parse_args() print(vars(parser)) return vars(parser) class DT: def __init__(self, criterion, splitter, max_depth , min_samples_split, min_samples_leaf , min_weight_fraction_leaf, max_leaf_nodes , random_state, max_features, data=None , X=None , y=None): """ 初始化函数 :param criterion: 决策树节点分裂准则选择 :param splitter: 决策树随机/最优选择特征划分 :param max_depth: 决策树最大深度 :param min_samples_split: 决策树父节点最少划分样本量 :param min_samples_leaf: 决策树叶子节点最少划分样本量 :param min_weight_fraction_leaf: 决策树忘了这是啥 :param max_leaf_nodes: 决策树最大叶子节点数 :param random_state: 随机模式参数 :param max_features: 决策树训练用最大特征数量 :param data: 数据集 :param X: X :param y: y """ self.criterion = parameters[criterion] self.splitter = parameters[splitter] self.max_depth = parameters[max_depth] self.min_samples_split = parameters[min_samples_split] self.min_samples_leaf = parameters[min_samples_leaf] self.min_weight_fraction_leaf = parameters[min_weight_fraction_leaf] self.max_leaf_nodes = parameters[max_leaf_nodes] self.random_state = parameters[random_state] self.max_features = parameters[max_features] self.data = data # 传进来的data包含标签列(y) self.feature_names = self.data.columns[:-1] if self.data is not None else None self.X = X self.y = y def create_datas(self): '''创建数据集''' wine = load_wine() self.X = wine.data self.y = wine.target self.feature_names = wine.feature_names self.data = pd.DataFrame(self.X, columns=self.feature_names) self.data['target'] = pd.DataFrame(self.y) def data_split(self): """训练集、测试集划分""" return train_test_split(self.X , self.y , test_size=0.3 , random_state=self.random_state) def Modeling(self, X_train, X_test, y_train, y_test): """ 训练决策树 """ clf = DecisionTreeClassifier(criterion=self.criterion , splitter=self.splitter , max_depth=self.max_depth , min_samples_split=self.min_samples_split , min_samples_leaf=self.min_samples_leaf , min_weight_fraction_leaf=self.min_weight_fraction_leaf , max_leaf_nodes=self.max_leaf_nodes , random_state=self.random_state , max_features=self.max_features ) clf.fit(X_train, y_train) score = clf.score(X_test, y_test) # accuracy print(f'decision tree accuracy:{score}') y_test_predict = clf.predict(X_test) # 测试集预测结果 report_of_classifier = classification_report(y_test, y_test_predict) print(report_of_classifier) try: # roc曲线 and p_r曲线 and 混淆矩阵 figure = plt.figure(num=1, figsize=(12, 8), dpi=300) ax1 = plt.subplot(212) plot_confusion_matrix(clf, X_test, y_test, ax=ax1) ax1.set_title('混淆矩阵') ax1.margins(2, 2) ax2 = plt.subplot(221) plot_roc_curve(clf, X_test, y_test, ax=ax2) ax2.set_title('ROC 曲线') ax3 = plt.subplot(222) plot_precision_recall_curve(clf, X_test, y_test, ax=ax3) ax3.set_title('P_R 曲线') # 当前程序文件名称(去除.py) cur_file_name = os.path.split(sys.argv[0])[1].split('.')[0] if not os.path.exists('save'): os.mkdir('save') # 图片存储路径 photo_save_path = os.path.join(cur_file_path, 'save', cur_file_name + '.png') figure.savefig(photo_save_path) plt.show() except ValueError: print('DecisionTreeClassifier should be a binary classifier') def run(self): if self.data is None: self.create_datas() else: self.X = self.data.iloc[:, :-1] self.y = self.data.iloc[:, -1] X_train, X_test, y_train, y_test = self.data_split() self.Modeling(X_train, X_test, y_train, y_test) if __name__ == '__main__': # 创建乳腺癌数据集 breast = load_breast_cancer() data = pd.DataFrame(breast.data, columns=breast.feature_names) data['target'] = pd.DataFrame(breast.target) parameters = get_parser() dt = DT(*parameters, data) # dt = DT(*parameters) dt.run()
-
可以直接运行,因为参数全都非必须,且都有默认值
$ python 10、Package_argparse.py {'criterion': 'entropy', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 2, 'min_weight_fraction_leaf': 0, 'max_leaf_nodes': None, 'random_state': 10, 'max_features': None} decision tree accuracy:0.9532163742690059 precision recall f1-score support 0 0.89 0.98 0.94 59 1 0.99 0.94 0.96 112 accuracy 0.95 171 macro avg 0.94 0.96 0.95 171 weighted avg 0.96 0.95 0.95 171
也可以命令行指定参数运行
$ python 10、Package_argparse.py --criterion gini {'criterion': 'gini', 'splitter': 'best', 'max_depth': None, 'min_samples_split': 2, 'min_samples_leaf': 2, 'min_weight_fraction_leaf': 0, 'max_leaf_nodes': None, 'random_state': 10, 'max_features': None} decision tree accuracy:0.9298245614035088 precision recall f1-score support 0 0.87 0.93 0.90 59 1 0.96 0.93 0.95 112 accuracy 0.93 171 macro avg 0.92 0.93 0.92 171 weighted avg 0.93 0.93 0.93 171
通过配置文件(以随机森林为例,两个参数做测试)
-
主体:
配置文件
n_estimators:10 criterion:gini
主体代码
将配置文件中的参数读取,设置为默认(defualt),所以也可以手动改参数
def read_conf(cls, conf_name): """""" conf_items = {} with open(conf_name) as f: line = f.readline() # 单行读取 while line: # 忽略空行和注视 if len(line.strip()) > 3 and (not line.strip().startswith('#')): kv_list = line.split(':') key = kv_list[0].strip() value = kv_list[1].strip() conf_items[key] = value line = f.readline() return conf_items def get_parser(): """ usage : 获取参数信息 process_path : 当前程序所在目录 conf_info_path : 参数配置文件地址 conf_dict : 处理完成返回的参数字典 :return: 参数字典 """ conf_info_path = os.path.join(process_path, 'need/randomforecast.conf') conf_dict = read_conf(conf_info_path) parse = argparse.ArgumentParser(description='为随机隋林算法提供参数') parse.add_argument('--n_estimators' , required=False , type=int , default=conf_dict['n_estimators'] , help='随机森林 基分类器数量') parse.add_argument('--criterion' , required=False , type=str , default=conf_dict['criterion'] , choices=['entropy', 'gini'] , help='随机森林 基分类器节点分裂准则') parse = parse.parse_args() parameters = vars(parse) print(parameters) return parameters
-
全部代码
# -*- coding: utf-8 -*- """ ************************************************** @author: Ying @software: PyCharm @file: GeneralTools.py @time: 2021-07-07 09:52 ************************************************** """ class GeneralTools: @classmethod def read_conf(cls, conf_name): """""" conf_items = {} with open(conf_name) as f: line = f.readline() # 单行读取 while line: # 忽略空行和注视 if len(line.strip()) > 3 and (not line.strip().startswith('#')): kv_list = line.split(':') key = kv_list[0].strip() value = kv_list[1].strip() conf_items[key] = value line = f.readline() return conf_items
# -*- coding: utf-8 -*- """ ************************************************** @author: Ying @software: PyCharm @file: 2、RandomForecastClassifier.py @time: 2021-07-07 09:50 ************************************************** """ import os import sys import argparse import pandas as pd from GeneralTools import GeneralTools from sklearn.datasets import load_breast_cancer from sklearn.model_selection import train_test_split from sklearn.ensemble import RandomForestClassifier from sklearn.preprocessing import LabelEncoder from sklearn.metrics import classification_report from matplotlib import pyplot as plt from sklearn.metrics import plot_confusion_matrix, plot_precision_recall_curve, plot_roc_curve, plot_det_curve global process_path process_path = os.path.split(sys.argv[0])[0] plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False plt.rcParams["font.family"] = 'Arial Unicode MS' def get_parser(): """ usage : 获取参数信息 process_path : 当前程序所在目录 conf_info_path : 参数配置文件地址 conf_dict : 处理完成返回的参数字典 :return: 参数字典 """ conf_info_path = os.path.join(process_path, 'need/randomforecast.conf') conf_dict = GeneralTools().read_conf(conf_info_path) parse = argparse.ArgumentParser(description='为随机隋林算法提供参数') parse.add_argument('--n_estimators' , required=False , type=int , default=conf_dict['n_estimators'] , help='随机森林 基分类器数量') parse.add_argument('--criterion' , required=False , type=str , default=conf_dict['criterion'] , choices=['entropy', 'gini'] , help='随机森林 基分类器节点分裂准则') parse = parse.parse_args() parameters = vars(parse) print(parameters) return parameters class RF: def __init__(self, n_estimators, criterion, data=None): self.criterion = parameters[criterion] self.n_estimators = parameters[n_estimators] self.data = data self.feature_names = self.data.columns if self.data is not None else None self.X = None self.y = None def transform_target(self, target): """仅对疫情target字段进行处理""" if target >= -0.2096: return '免疫力等级高' return '免疫力等级低' def data_preprocessing(self): """仅对疫情数据集进行处理""" self.data.drop(['企业名称', '分类', '企业性质_2'], axis=1, inplace=True) self.data.columns = ['industry', 'iso_num', 'qiye_nature', 'LC_performance', 'shareholder_num', 'shareholder_com', 'zl_zscq', 'ywfb', 'clsj', 'zczj', 'sjzj', 'person_num', 'difference'] self.data['target'] = data['difference'].apply(self.transform_target) self.data.drop('difference', axis=1) le = LabelEncoder() data['target'] = le.fit_transform(data['target']) modeling_data = data.join(pd.get_dummies(data['shareholder_com'])).join(pd.get_dummies(data['industry'])).join( pd.get_dummies(data['qiye_nature'])).drop(['industry', 'qiye_nature', 'shareholder_com'] , axis=1) target_data = modeling_data['target'] modeling_data.drop('target', axis=1, inplace=True) modeling_data = modeling_data.join(target_data) self.data = modeling_data self.X = self.data.iloc[:, :-1] self.y = self.data.iloc[:, -1] def split_data(self): return train_test_split(self.X , self.y , random_state=120 , test_size=0.3) def create_data(self): breast = load_breast_cancer() self.feature_names = breast.feature_names self.data = pd.DataFrame(breast.data, columns=self.feature_names) self.data['target'] = pd.Series(breast.target) self.X = breast.data self.y = breast.target def modelling(self, X_train, X_test, y_train, y_test): rf_clf = RandomForestClassifier(n_estimators=self.n_estimators , criterion=self.criterion) rf_clf.fit(X_train, y_train) acc = rf_clf.score(X_test, y_test) print(f"accuracy:{acc}") y_test_pred = rf_clf.predict(X_test) report_result = classification_report(y_test, y_test_pred) print(report_result) try: fig, axes = plt.subplots(2, 2) fig.set_size_inches(20, 12) fig.set_dpi(300) plot_roc_curve(rf_clf, X_test, y_test, ax=axes[0, 0]) axes[0, 0].set_title('roc 曲线') plot_precision_recall_curve(rf_clf, X_test, y_test, ax=axes[0, 1]) axes[0, 1].set_title('P_R 曲线') plot_confusion_matrix(rf_clf, X_test, y_test, ax=axes[1, 0]) axes[1, 0].set_title('混淆矩阵') plot_det_curve(rf_clf, X_test, y_test, ax=axes[1, 1]) axes[1, 1].set_title('det 曲线') plt.show() except: pass def run(self): if self.data is None: self.create_data() else: self.data_preprocessing() X_train, X_test, y_train, y_test = self.split_data() self.modelling(X_train, X_test, y_train, y_test) if __name__ == '__main__': parameters = get_parser() data = pd.read_excel(os.path.join(process_path, 'need/疫情数据集.xlsx')) rf = RF(*parameters, data) rf.run()
9. 封装
-
通用类
增加参数在
initialize
函数中增加import argparse import logging class BaseOptions: """BaseOptions: 定义参数使用 Examples ------------------------------------------------------------------------------------------------------------------- # 获取所有参数 >>> from options import BaseOptions >>> args = BaseOptions().parse() >>> # 取参数 >>> args.参数名 >>> # 新增参数 >>> args.新参数名=参数值 >>> # 修改参数 >>> args.已有参数=修改的参数值 >>> """ def __init__(self): """__init__: 初始化类,该类尚未初始化 """ self.initialized = False def initialize(self,parser): """initialize: 定义通用的参数 Parameters: ------------------------------------------------------------------------------------------------------------------- >>> parser.add_argument('--dataroot' # 变量名 , type=str # 读取参数的类型 , required=False # 可选项是否必须有 , default='a' # 默认值,类型与type统一 , choices=['a','b'] # 参数值只能从几个选项里面选择 , help='11111' # 该参数说明 ) ------------------------------------------------------------------------------------------------------------------- """ parser.add_argument('--dataroot' # 变量名 , type=str # 读取参数的类型 , required=False # 可选项是否必须有 , default='a' # 默认值,类型与type统一 , choices=['a','b'] # 参数值只能从几个选项里面选择 , help='11111' # 该参数说明 ) self.initialized = True return parser def gather_options(self): """gather_options: 初始化通用的参数(只初始化一次) """ if not self.initialized: parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) parser = self.initialize(parser) return parser.parse_args() def parse(self): """parse: 解析参数 """ opt = self.gather_options() self.opt = opt return self.opt