一个可以对基于Pytorch搭建的模型的训练过程进行全程追踪的模块

本文所述的trace模块实现了对损失和准确率的全过程跟踪,并在生成损失与准确率统计图的同时实现损失、准确率和模型本身的同步存储,这使得即使训练间断,训练过程中的数据仍然可以被保留和呈现。模块源文件可在本项目Github仓库中获取。下文是对该模块的说明,源自本项目的README文件,由于笔者考试周临近,暂时没时间写个中文版的介绍,先开个贴占个坑吧。

A Pytorch Training Process Tracing Module

Xiangnan Zhang

School of Future Technology, Beijing Institute of Technology


This module called "trace" is used to trace whole traning process, as well as realize visualization of accuracy and loss.

All loss and accuracy can be traced via a Statistic object. During training and testing, these data will be appended to its list-format attributes. The model and traning data will
be saved and loaded at the same time, hence to guarantee the whole-traning-process tracing. While saving the model and traning data, line charts of loss and accuracy will be shown and saved as follows.

image

Importing

In terms of importing this module, you should code as follows:

import trace
from trace import Statistic

NOTICE: The class Statistic() should be imported separately, because it is the preriquisit of function load_statistic()and sys_load().

Details

Statistic(path)

Objects that belong to this class stores traning and testing loss and accuracy. So when you are initializing your model, you should create a Statistic like this:

statis=Statistic(statis_path)

its __init__()method will establish attributes self.train_loss, self.train_accuracy, self.test_lossand self.test_accuracy. Each of them are empty list, so you can
use .append()method to append values in your train_loop and test_loop functions, like this:

def test_loop(test_ds,model,loss_fn,statis):
    model.eval()
    ...
    statis.test_loss.append(test_loss)
    statis.test_accuracy.append(accuracy)
    ...

There are two methods for Statistic project called self.draw()and self.save, which can be used to draw statistical images and save Statistic object as pkl files. However, in most cases you should use Sys object’s .sys_conclude()method instead.

load_statistic(path)

This function is used to load a Statistic object. But in most cases, you should use sys_load()function instead.

sys_load(model_path,statis_path)

This function is used to load both model and Statistic object. It is highly recommended that you use this function to load these two items, because it can guarentee that model and Statistic object can be loaded at the same time.

Sys(model,model_path,statis)

This class aims to process model and Statistic object at the same time. You should create a Sys project after the model and Statistic object are loaded or iinitialized, like this:

syst=trace.Sys(model,model_path,statis)

Then you can use .sys_conclude()method to save both the model and Statistic object:

syst.sys_conclude("ConvM(4_categ)")

A string that represents the model should be given when using this method. When you need to save your model and Statistic object, this method is always highly recommended.


After saving, the model will be saved as a pth file, and the Statistic object will be saved as a pkl file. These suffixes should be included into file paths.

Deficiency

When using this module to trace data, train_loss will dramatically increase at the begining of a new training process, which can be seen in the front image. However, I’m not sure whether it is the module’s problem, or it is my model’s problem.

  • 47
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

张向南zhangxn

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值