2021-10-22

 2021SC@SDUSC

#鉴于上次对train.py看的云里雾里的情况,决定重新从main开始进行逐行分析

main.py

import logging

logging模块是Python内置的标准模块,主要用于输出运行日志,可以设置输出日志的等级、日志保存路径、日志文件回滚等;相比print,具备如下优点:

  1. 可以通过设置不同的日志等级,在release版本中只输出重要信息,而不必显示大量的调试信息;
  2. print将所有信息都输出到标准输出中,严重影响开发者从标准输出中查看其它数据;logging则可以由开发者决定将信息输出到什么地方,以及怎么输出;

import warnings
Python 通过调用 warnings 模块中定义的 warn() 函数来发出警告。警告消息通常用于提示用户一些错误或者过时的用法,当这些情况发生时我们不希望抛出异常或者直接退出程序。警告消息通常写入 sys.stderr,对警告的处理方式可以灵活的更改,例如忽略或者转变为为异常。警告的处理可以根据警告类别,警告消息的文本和发出警告消息的源位置而变化。对相同源位置的特定警告的重复通常被抑制

在写程序时可以使用warning模块对警告处理或直接屏蔽。
 

import argparse

argparse 是 Python 内置的一个用于命令项选项与参数解析的模块,通过在程序中定义好我们需要的参数,argparse 将会从 sys.argv 中解析出这些参数,并自动生成帮助和使用信息。当然,Python 也有第三方的库可用于命令行解析,而且功能也更加强大,比如 docopt,Click。

from ACP_CSQA.model_params import *

from train import GraphC_QAModel

从train中获取模型

warnings.filterwarnings(action='ignore')

filterwarnings

warning.fillterwarning函数:

过滤警告,在 警告过滤器规则 列表中插入一个条目。默认情况下,条目插入在前面;如果 append 为真,则在末尾插入。它检查参数的类型,编译 message 和 module 的正则表达式,并将它们作为警告过滤器列表中的元组插入。如果多个地方都匹配特定的警告,那么更靠近列表前面的条目会覆盖列表中后面的条目,省略的参数默认为匹配一切的值

Eg:warnings.filterwarnings(action, message='', category=Warning, module='',

                        lineno=0, append=False)

logger = logging.getLogger('gct_dual')

初始化logger对象

日志等级分别有以下几种:

CRITICAL : 'CRITICAL',
ERROR : 'ERROR',
WARNING : 'WARNING',
INFO : 'INFO',
DEBUG : 'DEBUG',
NOTSET : 'NOTSET',

PARAMS_MAP = {

    'ACB_dual':AMR_CN_BERT_PARAMS,

}

来自ACP_CSQA.model_params

def train_model(args, local_rank):

转train.py

未完----

    args = PARAMS_MAP[args.encoder_type]

    if 'google' in args['lm_model']:

        lm_args = args['lm_model'][args['lm_model'].index('/')+1:]

    else:

        lm_args = args['lm_model']

    args['prefix'] = str(args['encoder_type']) + '_' + args['task'] + '_' + args['feature'] + '_lr' + str(

        args['lr']) + '_' + str(args['batch_multiplier']) + '_' + lm_args

    assert len(args['cnn_filters']) % 2 == 0

    model = GraphC_QAModel(args, local_rank)

    model.train()

def evaluate_model(args, local_rank):

    eval_file = args.eval_file

    gpus = args.gpus

    args = PARAMS_MAP[args.encoder_type]

    if 'google' in args['lm_model']:

        lm_args = args['lm_model'][args['lm_model'].index('/')+1:]

    else:

        lm_args = args['lm_model']

    args['prefix'] = str(args['encoder_type']) + '_' + args['task'] + '_' + args['feature'] + '_lr' + str(

        args['lr']) + '_' + str(args['batch_multiplier']) + '_' + lm_args

    assert len(args['cnn_filters']) % 2 == 0

    model = GraphC_QAModel(args, local_rank)

    model.evaluate_model(eval_file, gpus)

def get_params():

    # Training settings

    parser = argparse.ArgumentParser(description='gct_dual')

    parser.add_argument('--mode', dest="mode", type=str, default='eval')

    parser.add_argument("--encoder_type", dest="encoder_type", type=str, default=None,

                            help="Model Name")

    parser.add_argument("--eval_file", dest='eval_file', type=str, default=None)

    parser.add_argument("--gpus", type=int, default=0)

    args, _ = parser.parse_known_args()

    return args

if __name__ == "__main__":

    args = get_params()

    if args.mode == 'train':

        train_model(args, 0)

    else:

        evaluate_model(args, 0)

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值