我们写好一个python模块供他人在命令行下调用时,有时需要在命令行指定参数传入到模块,通过argparse包可以便捷地接收参数。
例如我用pytorch定义了一个神经网络模型,需要命令行运行main.py时传入参数给初始化函数Net.init()
python main.py --task train --batch_size 10 --result_file ./result/deeptte.res --pooling_method attention --kernel_size 3 --alpha 0.1 --log_file run_log
import torch.nn as nn
class Net(nn.Module):
def __init__(self, kernel_size=3, num_filter=32, pooling_method='attention', num_final_fcs=3, final_fc_size=128,
alpha=0.3):
super(Net, self).__init__()
self.kernel_size = kernel_size
self.num_filter = num_filter
self.pooling_method = pooling_method
self.num_final_fcs = num_final_fcs
self.final_fc_size = final_fc_size
self.alpha = alpha
......
使用argparse首先通过ArgumentParser()
创建一个解析器parser,然后通过add_argument()
为解析器添加要解析的参数,最后通过parse_args()
就可以获得命令行传来的参数
import argparse
# 创建解析器
parser = argparse.ArgumentParser()
# 添加参数
parser.add_argument('--task', type=str)
parser.add_argument('--batch_size', type=int, default=64)
parser.add_argument('--epochs', type=int, default=100)
parser.add_argument('--weight_file', type=str)
parser.add_argument('--result_file', type=str)
parser.add_argument('--kernel_size', type=int)
parser.add_argument('--pooling_method', type=str)
parser.add_argument('--alpha', type=float)
parser.add_argument('--log_file', type=str)
# 解析参数
args = parser.parse_args()
接下来需要对参数进行过滤并利用参数来初始化神经网络。我们获取的参数args有的可能并不是Net.init()所需要的,因此通过inspect.getargspec()
方法来获取Net.init()方法的参数列表model_args,然后将命令行获得的参数转化为dict并进行遍历,如果参数不在model_args中,则pop掉。最后利用剩下的参数去初始化Net
import inspect
model_args = inspect.getargspec(model_class.__init__).args # 初始化函数所需参数
shell_args = args._get_kwargs() # 命令行输入的参数
# 对参数进行过滤
kwargs = dict(shell_args)
for arg, val in shell_args:
if arg not in model_args:
kwargs.pop(arg)
# 使用参数初始化神经网络对象
model = Net(**kwargs)