发现了一个pytorch的高级训练库,github的地址为:
pytorch/ignitegithub.com主要亮点功能:
对于训练过程中的for循环,精简代码,提供度量,提前终止,保存模型,提供基于visdom和tensorBoardX的训练可视化。
基本概念
一. Engine
ignite框架最基本的概念,循环一定的次数,循环的过程为基于训练数据,更新模型的参数。也可以加上评估的过程,基于验证数据集,计算损失函数的值。示例代码:
def update_model(trainer, batch):
model.train()
optimizer.zero_grad()
x, y = prepare_batch(batch)
y_pred = model(x)
loss = loss_fn(y_pred, y)
loss.backward()
optimizer.step()
return loss.item()
trainer = Engine(update_model)
trainer.run(data, max_epochs=100)
二. Events and Handlers
为了在训练过程中和外界进行交互,引用事件的机制。
事件触发的时间点包括:
- engine 的开始,结束
- epoch 的开始,结束
- batch iteration的开始,结束
用户注册事件的处理函数Handler,处理函数Handler在框架的事件触发时,会被回调。注册有两种方式:
add_event_handler()
- on注解器
示例代码:
trainer = Engine(update_model)
trainer.add_event_handler(Events.STARTED, lambda engine: print("Start training"))
# or
@trainer.on(Events.STARTED)
def on_training_started(engine):
print("Another message of start training")
# attach handler with args, kwargs
mydata = [1, 2, 3, 4]
def on_training_ended(engine, data):
print("Training is ended. mydata={}".format(data))
trainer.add_event_handler(Events.COMPLETED, on_training_ended, mydata)
三。时间轴
如下是框架的时间轴,我们主要理解以下:
可以注册epoch结束时的处理函数,在此函数中可以进行在验证数据集上的验证过程,判断是否进行提早终止训练,或者更新学习率(想必每个知道深度神经网络的同学应该都知道动态学习率的概念)。
四。State
Engine
类中包含了 State的对象。State类主要包含
epoch,当前的轮数
max_epochs,训练的最大轮数
iteration,训练结束后的迭代次数
output,训练结束后,在Engine中定义的处理函数的输出
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item()
def on_iteration_completed(engine):
iteration = engine.state.iteration
epoch = engine.state.epoch
loss = engine.state.output
print("Epoch: {}, Iteration: {}, Loss: {}".format(epoch, iteration, loss))
trainer.add_event_handler(Events.ITERATION_COMPLETED, on_iteration_completed)
这个例子中,engine.state.output保存了损失函数的值。
engine.state.output保存的值是在Engine中定义的处理函数的输出,这个输出的类型是没有明确的,所以我们可以灵活使用。
在看一个例子:
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return loss.item(), y_pred, y
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
epoch = engine.state.epoch
loss = engine.state.output[0]
print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))
accuracy = Accuracy(output_transform=lambda x: [x[1], x[2]])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
这个例子中,在Engine中定义的处理函数的输出,也就是update函数的返回值为一个tuple:loss,y_pred, y
对比,看这个例子:
def update(engine, batch):
x, y = batch
y_pred = model(inputs)
loss = loss_fn(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
return {'loss': loss.item(),
'y_pred': y_pred,
'y': y}
trainer = Engine(update)
@trainer.on(Events.EPOCH_COMPLETED)
def print_loss(engine):
epoch = engine.state.epoch
loss = engine.state.output['loss']
print ('Epoch {epoch}: train_loss = {loss}'.format(epoch=epoch, loss=loss))
accuracy = Accuracy(output_transform=lambda x: [x['y_pred'], x['y']])
accuracy.attach(trainer, 'acc')
trainer.run(data, max_epochs=10)
在Engine中定义的处理函数的输出,也就是update函数的返回值为一个字典:loss,y_pred, y。所以其他地方访问engine.state.output中的数据时,需要按照字典的方式。