(自用)代码研读:TextCNN模型代码分析之run.py

总的数据流程:

1.数据集:分为数据和标签

2.处理后送入build_dataset构建,再经过迭代DatasetIterater处理得到批次数据,送入train中将每一个批次索引转化为词向量形式训练

!!送入train之前需要先构建词汇表以及对应的词向量嵌入矩阵,以便前向传播的时候将索引转化为向量处理!!

需要注意的是:数据的流动以及形式转换的整个过程

# coding: UTF-8
import time
import torch
import numpy as np
from train_eval import train, init_network
from importlib import import_module#动态模块导入工具:import_module 是 Python 标准库 importlib 模块中的一个函数,用于在运行时动态导入模块。与常规的 import 语句相比,import_module 允许你在程序运行时根据需要动态地导入模块。
import argparse#命令行参数解析工具。

parser = argparse.ArgumentParser(description='Chinese Text Classification')#定义命令行参数解析器:用于从命令行读取模型名称、词嵌入类型和分词方式
parser.add_argument('--model', type=str, required=True, help='choose a model: TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer')#--model:选择使用的模型,必须提供。
parser.add_argument('--embedding', default='pre_trained', type=str, help='random or pre_trained')#--embedding:选择词嵌入类型,默认为预训练词嵌入。
parser.add_argument('--word', default=False, type=bool, help='True for word, False for char')#--word:选择分词方式,默认为按字符分词。
args = parser.parse_args()


if __name__ == '__main__':#主程序入口:检查是否作为主程序运行,确保以下代码仅在直接运行脚本时执行。
    dataset = 'THUCNews'  # 数据集

    # 搜狗新闻:embedding_SougouNews.npz, 腾讯:embedding_Tencent.npz, 随机初始化:random
    embedding = 'embedding_SougouNews.npz'#设置词嵌入文件路径:根据命令行参数选择词嵌入类型。如果选择随机初始化,则设置 embedding 为 random。
    if args.embedding == 'random':
        embedding = 'random'
    model_name = args.model  # 'TextRCNN'  # TextCNN, TextRNN, FastText, TextRCNN, TextRNN_Att, DPCNN, Transformer  获取模型名称:从命令行参数中获取模型名称。
    if model_name == 'FastText':#根据模型名称导入相应的工具函数:如果是 FastText 模型,使用 utils_fasttext 中的工具函数,并强制将词嵌入设置为 random。否则,使用 utils 中的工具函数。
        from utils_fasttext import build_dataset, build_iterator, get_time_dif
        embedding = 'random'
    else:
        from utils import build_dataset, build_iterator, get_time_dif

    x = import_module('models.' + model_name)#动态导入模型模块:根据模型名称动态导入相应的模型模块。
    config = x.Config(dataset, embedding)#初始化模型配置:使用模型模块中的 Config 类初始化配置。
    np.random.seed(1)#设置随机种子:确保每次运行的结果一致,设置 NumPy 和 Torch 的随机种子,并使得 CuDNN 后端的计算是确定性的。
    torch.manual_seed(1)
    torch.cuda.manual_seed_all(1)
    torch.backends.cudnn.deterministic = True  # 保证每次结果一样

    start_time = time.time()#记录开始时间:用于计算数据加载时间。
    print("Loading data...")
    vocab, train_data, dev_data, test_data = build_dataset(config, args.word)#加载数据集:使用 build_dataset 函数加载词汇表、训练集、验证集和测试集。
    train_iter = build_iterator(train_data, config)#构建数据迭代器:使用 build_iterator 函数为训练集、验证集和测试集构建数据迭代器。
    dev_iter = build_iterator(dev_data, config)
    test_iter = build_iterator(test_data, config)
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)#计算并打印数据加载时间。

    # train:初始化模型并训练
    config.n_vocab = len(vocab)#设置词汇表大小:将词汇表大小设置到配置中。
    model = x.Model(config).to(config.device)#初始化模型:根据配置初始化模型,并将模型移动到指定设备(CPU 或 GPU)。
    if model_name != 'Transformer':#初始化模型权重:如果模型不是 Transformer,则使用 init_network 函数初始化模型权重。
        init_network(model)
    print(model.parameters)#打印模型参数:打印模型的参数信息。
    train(config, model, train_iter, dev_iter, test_iter)#训练模型:调用 train 函数进行模型训练。

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sparkling*

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值