文本分类(二) | (2) 程序入口

完整项目

run.py是整个项目的入口,它包含两部分,一是使用argparse工具,配置相关参数;二是整个项目的流程框架,各个模块/函数的调用。

目录

1. 参数配置

2. 项目流程


1. 参数配置

#声明argparse对象 可附加说明
parser = argparse.ArgumentParser(description='Chinese Text Classification')

#添加参数
#模型是必须设置的参数(required=True) 类型是字符串
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')
#embedding随机初始化或使用预训练词或字向量 默认使用预训练
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')
#基于词还是基于字 默认基于字
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')

#解析参数
args = parser.parse_args()

2. 项目流程

if __name__ == '__main__':
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz' #默认使用搜狗预训练字向量
    if args.embedding == 'random': #如果embedding参数设置为random
        embedding = 'random'
    
    #获取选择的模型名字
    model_name = args.model  # 'TextRCNN'  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer
    
    #导入数据预处理与加载函数
    if model_name == 'FastText': #如果所选模型名字为FastText 由于增加了bi-gram tri-gram特征 会有不同的行为
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random' #此时embedding需要设置为随机初始化
    else: #其他模型统一处理
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name) #根据所选模型名字在models包下 获取相应模块(.py)
    config = x.Config(dataset, embedding) #每一个模块(.py)中都有一个模型定义类 和与该模型相关的配置类(定义该模型的超参数) 初始化配置类的对象
    
    #设置随机种子 确保每次运行的条件(模型参数初始化、数据集的切分或打乱等)是一样的
    np.random.seed(1)
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  


    start_time = time.time()
    print("Loading data...")
    #数据预处理
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word) #构建词典、训练集、验证集、测试集
    #构建训练集、验证集、测试集迭代器/生成器(节约内存、避免溢出)
    train_iter = build_iterator(train_data, config)
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)

    # 构造模型对象
    config.n_vocab = len(vocab) #词典大小可能不确定,在运行时赋值
    model = x.Model(config).to(config.device) #构建模型对象 并to_device
    
    if model_name != 'Transformer': #如果不是Transformer模型 则使用自定义的参数初始化方式
        init_network(model)    #也可以采用之前达观杯中的做法 把自定义模型参数的函数 放在模型的定义类中 在__init__中执行
    print(model.parameters)
    
    #训练、验证和测试
    train(config, model, train_iter, dev_iter, test_iter) 
  • 自定义参数初始化方式
# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():
        if exclude not in name: #排除embedding层
            #权重按指定方式初始化 默认xavier
            if 'weight' in name:
                if method == 'xavier':
                    nn.init.xavier_normal_(w)
                elif method == 'kaiming':
                    nn.init.kaiming_normal_(w)
                else:
                    nn.init.normal_(w)
            #偏置初始化为常数0
            elif 'bias' in name:
                nn.init.constant_(w, 0)
            else:
                pass

 

 

 

 

 

  • 10
    点赞
  • 36
    收藏
    觉得还不错? 一键收藏
  • 7
    评论
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值