(自用)代码研读:TextCNN模型代码分析之train_eval.py

训练代码(将训练以及测试过程记录在tensorboardX)

train调用evaluate以及test

test调用evaluate

# coding: UTF-8
"""
导入必要的库:
numpy, torch, torch.nn, torch.nn.functional:用于处理张量和构建模型。
metrics:用于计算准确率和其他评估指标。
time:用于计算时间差异。
get_time_dif:自定义函数,用于计算时间差异。
SummaryWriter:用于记录训练日志和可视化数据。
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn import metrics
import time
from utils import get_time_dif
from tensorboardX import SummaryWriter

"""
函数定义:
init_network:用于初始化模型权重。
train:用于训练模型。
test:用于测试模型。
evaluate:用于在验证或测试数据集上评估模型。
"""


"""
功能:初始化模型的权重和偏置。
参数:
model:待初始化的模型。
method:初始化方法,默认为 'xavier'。
exclude:排除不初始化的层,默认为 'embedding'。
seed:随机种子,默认为123。
实现:遍历模型的所有参数,根据名称和类型进行初始化。

整体流程
遍历模型的所有参数:获取每个参数的名字和张量。
排除特定参数:如果参数名字中包含 exclude 字符串,则跳过该参数。
检查参数类型:根据参数名字判断是权重还是偏置。
初始化权重:根据指定的方法(Xavier、Kaiming 或普通正态分布)初始化权重。
初始化偏置:将偏置初始化为 0。
"""
# 权重初始化,默认xavier
def init_network(model, method='xavier', exclude='embedding', seed=123):
    for name, w in model.named_parameters():#遍历模型的所有参数:model.named_parameters() 返回一个生成器,生成模型中所有参数的名字和参数张量。name:参数的名字。w:参数张量。
        if exclude not in name:#排除特定参数:检查参数名字中是否包含 exclude 字符串。如果 exclude 不在参数名字中,则继续进行初始化。
            if 'weight' in name:#检查参数是否为权重:检查参数名字中是否包含 'weight' 字符串。如果是权重参数,则根据指定的方法进行初始化。
                if method == 'xavier':#如果 method 为 'xavier',则使用 Xavier 正态分布初始化权重。
                    nn.init.xavier_normal_(w)#nn.init.xavier_normal_(w):使用 Xavier 正态分布初始化参数 w。
                elif method == 'kaiming':#如果 method 为 'kaiming',则使用 Kaiming 正态分布初始化权重。
                    nn.init.kaiming_normal_(w)#nn.init.kaiming_normal_(w):使用 Kaiming 正态分布初始化参数 w。
                else:#如果 method 既不是 'xavier' 也不是 'kaiming',则使用普通正态分布初始化权重
                    nn.init.normal_(w)
            elif 'bias' in name:#如果参数名字中包含 'bias' 字符串,则将其初始化为常数 0。
                nn.init.constant_(w, 0)#nn.init.constant_(w, 0):将偏置参数 w 初始化为 0。
            else:#如果参数既不是权重也不是偏置,则不进行任何操作。
                pass


"""
功能:训练模型,并在训练过程中进行验证。
参数:
config:包含超参数和路径配置的对象。
model:待训练的模型。
train_iter:训练数据迭代器。
dev_iter:验证数据迭代器。
test_iter:测试数据迭代器。
实现:
记录开始时间。
初始化优化器。
进行多轮训练(按 num_epochs 迭代)。
每个 epoch 中,遍历 train_iter 进行训练,计算损失,反向传播,更新参数。
每 100 个批次,计算并输出训练集和验证集上的效果(损失和准确率),保存模型参数。
如果验证集上的损失在 require_improvement 个批次内没有下降,则提前停止训练。
训练结束后,调用 test 函数在测试集上评估模型。
"""
def train(config, model, train_iter, dev_iter, test_iter):
    start_time = time.time()
    model.train()#设置模型为训练模式:启用 Batch Normalization 和 Dropout。
    optimizer = torch.optim.Adam(model.parameters(), lr=config.learning_rate)#定义优化器:使用 Adam 优化器来更新模型参数,学习率从配置中获取。

    # 学习率指数衰减,每次epoch:学习率 = gamma * 学习率
    # scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)#这段代码定义了一个学习率调度器,用于在每个 epoch 后指数衰减学习率。
    total_batch = 0  # 记录进行到多少batch
    dev_best_loss = float('inf')#初始化验证集最佳损失:设置初始值为正无穷大,以便在训练过程中更新。
    last_improve = 0  # 记录上次验证集loss下降的batch数,用于判断是否需要早停
    flag = False  # 记录是否很久没有效果提升
    writer = SummaryWriter(log_dir=config.log_path + '/' + time.strftime('%m-%d_%H.%M', time.localtime()))#初始化 TensorBoard 日志记录器:用于记录训练过程中的各种指标,便于在 TensorBoard 中进行可视化。
    for epoch in range(config.num_epochs):#开始训练循环:遍历每个 epoch,并打印当前 epoch 的信息。
        print('Epoch [{}/{}]'.format(epoch + 1, config.num_epochs))
        # scheduler.step() # 学习率衰减:如果启用学习率调度器,这段代码将在每个 epoch 结束时调整学习率。
        for i, (trains, labels) in enumerate(train_iter):#遍历训练数据迭代器:每次迭代获取一个批次的训练数据和对应的标签
            outputs = model(trains)#前向传播:将训练数据输入模型,获取预测结果。
            model.zero_grad()#梯度清零:在反向传播之前清除所有优化过的梯度。
            loss = F.cross_entropy(outputs, labels)#计算损失:使用交叉熵损失函数计算预测结果与真实标签之间的损失。
            loss.backward()#反向传播:计算损失相对于模型参数的梯度。
            optimizer.step()#优化器更新:在训练过程中,优化器根据模型参数的梯度来调整模型的参数值,以使得损失函数最小化
            if total_batch % 100 == 0:#每 100 个批次打印一次训练和验证结果:
                # 每多少轮输出在训练集和验证集上的效果
                true = labels.data.cpu()#true:将标签数据移动到 CPU。
                """在深度学习训练过程中,数据和模型参数的计算通常会在 GPU 上进行,以利用其高效的并行计算能力。
                然而,某些操作(如准确率计算、打印输出等)并不需要在 GPU 上进行,这时可以将数据移动到 CPU 上处理。
                """
                predic = torch.max(outputs.data, 1)[1].cpu()#predic:获取预测结果中最大值的索引。
                train_acc = metrics.accuracy_score(true, predic)#train_acc:计算训练集准确率。
                dev_acc, dev_loss = evaluate(config, model, dev_iter)#在验证集上评估模型:计算验证集的准确率和损失。
                if dev_loss < dev_best_loss:#保存最佳模型:如果当前验证集损失小于之前的最佳损失,则更新最佳损失。保存当前模型状态。标记本次训练有改进,并更新上次改进的批次计数
                    dev_best_loss = dev_loss
                    torch.save(model.state_dict(), config.save_path)
                    improve = '*'
                    last_improve = total_batch
                else:#如果没有改进:设置 improve 标记为空。
                    improve = ''
                time_dif = get_time_dif(start_time)#计算训练时间:获取从开始到当前的时间差。
                msg = 'Iter: {0:>6},  Train Loss: {1:>5.2},  Train Acc: {2:>6.2%},  Val Loss: {3:>5.2},  Val Acc: {4:>6.2%},  Time: {5} {6}'
                print(msg.format(total_batch, loss.item(), train_acc, dev_loss, dev_acc, time_dif, improve))#打印训练和验证结果:输出当前批次的训练损失、训练准确率、验证损失、验证准确率和训练时间。
                writer.add_scalar("loss/train", loss.item(), total_batch)#记录指标到 TensorBoard:将当前批次的损失和准确率记录到 TensorBoard。
                writer.add_scalar("loss/dev", dev_loss, total_batch)
                writer.add_scalar("acc/train", train_acc, total_batch)
                writer.add_scalar("acc/dev", dev_acc, total_batch)
                model.train()#重新设置为训练模式:确保模型在进入下一个批次之前处于训练模式。
            total_batch += 1#更新批次计数:增加批次计数。
            if total_batch - last_improve > config.require_improvement:#早停检查:如果在 require_improvement 个批次内没有改进,打印提示信息并设置早停标志。
                # 验证集loss超过1000batch没下降,结束训练
                print("No optimization for a long time, auto-stopping...")
                flag = True
                break
        if flag:#早停:如果早停标志被设置,跳出 epoch 循环。
            break
    writer.close()#关闭 TensorBoard 日志记录器。
    test(config, model, test_iter)#测试模型:在测试集上评估模型性能。

"""
定义函数:函数名为 test,用于在测试数据集上评估模型。
参数:
config:配置对象,包含模型路径、设备等信息。
model:要评估的模型。
test_iter:测试数据迭代器。
"""
def test(config, model, test_iter):
    # test
    model.load_state_dict(torch.load(config.save_path))#加载模型参数:从配置的保存路径加载训练好的模型参数到模型中。
    model.eval()#设置评估模式:禁用 Batch Normalization 和 Dropout 层。模型在评估模式下会使用整个训练过程中的平均值和方差。
    start_time = time.time()#记录开始时间:用于计算评估所需的总时间。
    test_acc, test_loss, test_report, test_confusion = evaluate(config, model, test_iter, test=True)#调用 evaluate 函数:在测试数据集上评估模型,获取测试集上的准确率(test_acc)、损失(test_loss)、分类报告(test_report)和混淆矩阵(test_confusion)。test=True:表示这是在测试模式下运行,evaluate 函数会返回更多详细信息。
    msg = 'Test Loss: {0:>5.2},  Test Acc: {1:>6.2%}'#定义格式化字符串:msg 用于格式化输出测试损失和测试准确率。。
    print(msg.format(test_loss, test_acc))#print 打印测试结果:将测试损失和测试准确率格式化后打印出来
    print("Precision, Recall and F1-Score...")
    print(test_report)#打印分类报告:包含每个类的精确度(Precision)、召回率(Recall)和 F1 分数。
    print("Confusion Matrix...")
    print(test_confusion)#打印混淆矩阵:显示预测结果与真实结果之间的匹配情况。
    time_dif = get_time_dif(start_time)
    print("Time usage:", time_dif)#打印时间差:输出评估所花费的时间。

"""
定义函数:evaluate 函数用于在给定的数据迭代器上评估模型的性能。
参数:
config:配置对象,包含类别列表等信息。
model:要评估的模型。
data_iter:数据迭代器,用于提供测试或验证数据。
test:布尔值,表示是否在测试模式下运行,默认为 False。

评估模式:用于在验证集上评估模型性能,主要关注准确率和损失。
测试模式:用于在测试集上评估模型的最终性能,并生成详细的分类报告和混淆矩阵。
"""
def evaluate(config, model, data_iter, test=False):
    model.eval()#设置评估模式:禁用 Batch Normalization 和 Dropout 层,确保评估时的稳定性和一致性。
    loss_total = 0#loss_total:累计损失的变量,初始值为 0。
    predict_all = np.array([], dtype=int)#predict_all:存储所有预测标签的数组,初始为空。
    labels_all = np.array([], dtype=int)#labels_all:存储所有真实标签的数组,初始为空。
    with torch.no_grad():#禁用梯度计算:在评估阶段不需要计算梯度,以节省内存和计算资源。
        for texts, labels in data_iter:
            outputs = model(texts)
            loss = F.cross_entropy(outputs, labels)
            loss_total += loss
            labels = labels.data.cpu().numpy()
            predic = torch.max(outputs.data, 1)[1].cpu().numpy()
            labels_all = np.append(labels_all, labels)
            predict_all = np.append(predict_all, predic)
    """
    遍历 data_iter:逐批次加载数据。
    前向传播:将输入数据传递给模型,得到输出结果。
    计算损失:使用交叉熵损失函数计算当前批次的损失,并累计到 loss_total。
    移动到 CPU:将标签数据从 GPU 移动到 CPU,并转换为 NumPy 数组。
    预测标签:取出每个样本的最大概率对应的类别作为预测结果,并移动到 CPU。
    累加标签和预测结果:将当前批次的真实标签和预测标签添加到 labels_all 和 predict_all 中。
    """
    acc = metrics.accuracy_score(labels_all, predict_all)#计算准确率:使用 sklearn 的 accuracy_score 函数计算所有标签和预测结果的准确率
    if test:
        report = metrics.classification_report(labels_all, predict_all, target_names=config.class_list, digits=4)#生成分类报告:使用 classification_report 函数生成详细的分类报告,包括精确度(precision)、召回率(recall)和 F1 分数(f1-score)。
        confusion = metrics.confusion_matrix(labels_all, predict_all)#生成混淆矩阵:使用 confusion_matrix 函数生成混淆矩阵,显示预测结果和真实结果的匹配情况。
        return acc, loss_total / len(data_iter), report, confusion#返回值:在测试模式下,返回准确率、平均损失、分类报告和混淆矩阵。
    return acc, loss_total / len(data_iter)#返回值:在非测试模式下,返回准确率和平均损失。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

sparkling*

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值