【Python基础】parse_args() | HfArgumentParser | 读入json参数文件

note

一、parse_args()的使用

parse_args()的使用:

  • 建立解析对象:parser = argparse.ArgumentParser()
  • 给parser实例添加属性:parser.add_argument('-epoches', type=int, default=15, help='batch size for dataloader')
  • 增加属性;
  • 通过args = parser.parse_args()把刚才的属性从parserargs,后面直接通过args使用。
import argparse

if __name__ == "__main__":
    # 建立解析对象
    parser = argparse.ArgumentParser()
    
    # 给parser实例添加属性
    parser.add_argument('-gpu', action='store_true', default=True, help='use gpu or not')
    parser.add_argument('-bs', type=int, default=128, help='batch size for dataloader')
    parser.add_argument('-epoches', type=int, default=15, help='batch size for dataloader')
    
    # 把刚才的属性给args实例,后面就可以直接使用
    args = parser.parse_args()

    continous_feature_names = ['releaseYear', 'movieRatingCount', 'movieAvgRating', 'movieRatingStddev',
                               'userRatingCount', 'userAvgRating', 'userRatingStddev']
    categorial_feature_names = ['userGenre1', 'userGenre2', 'userGenre3', 'userGenre4', 'userGenre5',
                                'movieGenre1', 'movieGenre2', 'movieGenre3', 'userId', 'movieId']

    categorial_feature_vocabsize = [20] * 8 + [30001] + [1001]
    
    # build dataset for train and test
    batch_size = args.bs
    train_data = build_dataset(args.train_path)
    loader_train = DataLoader(train_data, batch_size=batch_size, num_workers=64, shuffle=True, pin_memory=True)
    test_data = build_dataset(args.test_path)
    loader_test = DataLoader(test_data, batch_size=batch_size, num_workers=64)

    device = torch.device("cuda" if args.gpu else "cpu")
    # train model
    model = WideDeep(categorial_feature_vocabsize, continous_feature_names, categorial_feature_names, embed_dim=64)
    if args.gpu:
        model = model.to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-3)
    best_acc = 0

	# 这里就直接使用args.epoches
    for ep in range(args.epoches):
        train(ep)
        best_acc = test(ep, best_acc)

二、HfArgumentParser

这里也介绍下sys.argv是python的一个列表,包含命令行参数,如sys.argv[0]是脚本名称,sys.argv[1]是第一个参数,一次类推,像下面这种就是通过len(sys.argv)和第一个参数是否为json文件后缀来进行程序判断。
另:nlp中经常可以用到huggingface的HfArgumentParser的实例对象parser,通过model_args, data_args, training_args = parser.parse_args_into_dataclasses()自动将所有参数分到三个变量。

import parser
import subprocess
from arguments import ModelArguments, DataTrainingArguments
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    HfArgumentParser,
    Seq2SeqTrainingArguments,
    set_seed,
)
import os, sys

parser = HfArgumentParser(
    (ModelArguments, DataTrainingArguments, Seq2SeqTrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
# cmd = """./parser.sh"""
# # 执行命令,并捕获命令行输出和返回值
# result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
    # If we pass only one argument to the script and it's the path to a json file,
    # let's parse it to get our arguments.
    model_args, data_args, training_args = parser.parse_json_file(
        json_file=os.path.abspath(sys.argv[1]))
else:
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
   
# 方法二
json_path = "./parser.json"
model_args, data_args, training_args = parser.parse_json_file(json_file=json_path)

# 输出命令行输出和返回值
# print("Command line output:")
# print(result.stdout.decode())
# print(result.stderr.decode())
# print("Return code:", result.returncode)

三、读入json文件参数

# 方法一:
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pt_data_path", type=str, required=True)
    parser.add_argument("--json_data_path", type=str, required=True)
    parser.add_argument("--json_save_path", type=str, required=True)
    parser.add_argument("--sent_type", type=int, default=0)
    parser.add_argument("--ppl_type", type=int, default=0)
    parser.add_argument("--cluster_method", type=str, default='kmeans')
    parser.add_argument("--reduce_method", type=str, default='tsne')
    parser.add_argument("--sample_num", type=int, default=10)
    parser.add_argument("--kmeans_num_clusters", type=int, default=100)
    parser.add_argument("--low_th", type=int, default=1)
    parser.add_argument("--up_th", type=int, default=99)

    args = parser.parse_args()
    return args
    
def main():
    args = parse_args()

# 方法二: 用常规方法读入json
with open("cluster.json", "r") as f:
    cluster_json = f.read()
params = json.loads(cluster_json)
# 4. 创建自定义参数解析器
parser = argparse.ArgumentParser()
# 5. 遍历参数字典,并添加到自定义解析器中
for param_name, param_value in params.items():
    # 使用 add_argument() 方法将每个参数添加到解析器中
    parser.add_argument(f'--{param_name}', type=type(param_value), default=param_value)
# 6. 解析命令行参数
args = parser.parse_args()
print(args)

'''
{
    "pt_data_path": "/Users/guomiansheng/Desktop/LLM/Cherry_LLM/alpaca_data_cherry.pt",
    "json_data_path": "../data/alpaca_data.json",
    "json_save_path": "../alpaca_data_pre.json",
    "sample_num": 10,
    "kmeans_num_clusters": 100,
    "low_pt": 25,
    "up_th": 75
}
'''

Reference

[1] Py官网:argparse — 命令行选项、参数和子命令解析器

  • 14
    点赞
  • 46
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

山顶夕景

小哥哥给我买个零食可好

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值