Pytorch的ignite库

Pytorch的ignite库是一个high-level封装训练和测试代码的库,使用库里的对象和函数,我们就会更加简洁的写出训练和测试模型的代码,下面先给出具体的使用例子:

from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss

model = Net()
train_loader, val_loader = get_data_loaders(train_batch_size, val_batch_size)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.8)
criterion = nn.NLLLoss()

trainer = create_supervised_trainer(model, optimizer, criterion)

val_metrics = {
    "accuracy": Accuracy(),
    "nll": Loss(criterion)
}
evaluator = create_supervised_evaluator(model, metrics=val_metrics)

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(trainer):
    print("Epoch[{}] Loss: {:.2f}".format(trainer.state.epoch, trainer.state.output))

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_results(trainer):
    evaluator.run(train_loader)
    metrics = evaluator.state.metrics
    print("Training Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

@trainer.on(Events.EPOCH_COMPLETED)
def log_validation_results(trainer):
    evaluator.run(val_loader)
    metrics = evaluator.state.metrics
    print("Validation Results - Epoch: {}  Avg accuracy: {:.2f} Avg loss: {:.2f}"
          .format(trainer.state.epoch, metrics["accuracy"], metrics["nll"]))

trainer.run(train_loader, max_epochs=100)

(1)首先定义好基本元件:model,dataloader,optimizer,loss_criterion。

(2)使用create_supervised_trainer创建trainer engine,传入model, optimizer, criterion

(3)使用create_supervised_evaluator创建evaluator engine,传入model, metrics,其中metrics是一个字典来存储需要度量的指标。

(4)@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))是触发迭代log_interval (是个整数)step结束后,触发函数def log_training_loss(trainer):。这里代码段:

@trainer.on(Events.ITERATION_COMPLETED(every=log_interval))
def log_training_loss(engine):
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))

等价于

def log_training_loss(engine):
    print("Epoch[{}] Loss: {:.2f}".format(engine.state.epoch, engine.state.output))

trainer.add_event_handler(Events.ITERATION_COMPLETED, log_training_loss)

(5)最后我们使用trainer.run启动trainer engine来进行训练,epoch数为100。

重写create_supervised_trainer和create_supervised_evaluator

以上两种函数都可以重写且函数名可以修改,具有如下规则:

规则是在函数内部实现一个def _update(engine, batch),关键参数名engine和batch不要动,内容可以修改,最后返回Engine(_update),一个Engine对象传入_update方法。

def create_supervised_trainer(model, optimizer, loss_fn,
                              device=None):
    """
    Factory function for creating a trainer for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        optimizer (`torch.optim.Optimizer`): the optimizer to use
        loss_fn (torch.nn loss function): the loss function to use
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.

    Returns:
        Engine: a trainer engine with supervised update function
    """
    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _update(engine, batch):
        model.train()
        optimizer.zero_grad()
        img, target = batch
        img = img.to(device) if torch.cuda.device_count() >= 1 else img
        target = target.to(device) if torch.cuda.device_count() >= 1 else target
        score, feat = model(img)
        loss = loss_fn(score, feat, target)
        loss.backward()
        optimizer.step()
        # compute acc
        acc = (score.max(1)[1] == target).float().mean()
        return loss.item(), acc.item()

    return Engine(_update)
  • create_supervised_evaluator

原始实现:https://pytorch.org/ignite/_modules/ignite/engine.html#create_supervised_evaluator

def create_supervised_evaluator(model, metrics,
                                device=None):
    """
    Factory function for creating an evaluator for supervised models

    Args:
        model (`torch.nn.Module`): the model to train
        metrics (dict of str - :class:`ignite.metrics.Metric`): a map of metric names to Metrics
        device (str, optional): device type specification (default: None).
            Applies to both model and batches.
    Returns:
        Engine: an evaluator engine with supervised inference function
    """
    if device:
        if torch.cuda.device_count() > 1:
            model = nn.DataParallel(model)
        model.to(device)

    def _inference(engine, batch):
        model.eval()
        with torch.no_grad():
            data, pids, camids = batch
            data = data.to(device) if torch.cuda.device_count() >= 1 else data
            feat = model(data)
            return feat, pids, camids

    engine = Engine(_inference)

    for name, metric in metrics.items():
        metric.attach(engine, name)

    return engine

规则是在函数内部实现一个def _inference(engine, batch),关键参数名engine和batch不要动,内容可以修改,最后返回Engine(_inference),一个Engine对象传入_inference方法。此外metric.attach()里的metric是ignite.metrics下的对象,会于evaluator绑定,在run时计算metric。

Event

Event类定义了几种变量用于指示事件。

class Events(Enum):
    EPOCH_STARTED = "epoch_started"               # 当一个新的 epoch 开始时会触发此事件
    EPOCH_COMPLETED = "epoch_completed"           # 当一个 epoch 结束时, 会触发此事件
    STARTED = "started"                           # 开始训练模型是, 会触发此事件
    COMPLETED = "completed"                       # 当训练结束时, 会触发此事件
    ITERATION_STARTED = "iteration_started"       # 当一个 iteration 开始时, 会触发此事件
    ITERATION_COMPLETED = "iteration_completed"   # 当一个 iteration 结束时, 会触发此事件
    EXCEPTION_RAISED = "exception_raised"         # 当有异常发生时, 会触发此事件
©️2020 CSDN 皮肤主题: 编程工作室 设计师:CSDN官方博客 返回首页