torchnet 简单使用文档

最近复现prototypical net发现作者源代码使用了torchnet中的meter,还写了一个engine。所以我用一天时间看了下这个库。其实不用这个也是完全没有任何问题的。它只是方便复用的一个框架
torchnet是用于torch的代码复用和模块化编程
文档连接:https://tnt.readthedocs.io/en/latest/
主要包含4个部分:

  • Dataset: 各种不同方式处理数据。
  • Engine: 各种机器学习算法
  • Meter: 性能度量评估
  • Log:
    • 模块详细分为如下部分:
    • Datasets:
      • BatchDataset
      • ListDataset
      • ResampleDataset
      • ShuffleDataset
      • TensorDataset [new]
      • TransformDataset
    • Meters
      • APMeter
      • mAPMeter
      • AverageValueMeter
      • AUCMeter
      • ClassErrorMeter
      • ConfusionMeter
      • MovingAverageValueMeter
      • MSEMeter
      • TimeMeter
    • Engines
      • Engine
    • Logger
      • Logger
      • VisdomLogger
      • MeterLogger [new, easy to plot multi-meter via Visdom]

功能

主要用于可视化、数据处理和存取、日志管理
在这里插入图片描述
原本是基于 l u a − t o r c h lua-torch luatorch的一个库,迁移到python中来,变成 p y t o r c h pytorch pytorch的一部分

安装

  • 先保证pytorch已经安装,接着,
  • pip install torchnet
  • 或者: pip install git+https://github.com/pytorch/tnt.git@master

当前主要用的Dataset和Meter模块

Dataset部分

抽象类:classtorchnet.dataset.dataset.Dataset
传递的dataset是一个可迭代对象即可, 不过BatchDataset必须让每一个元素为一个dict。
产生batch形式的数据:
torchnet.dataset.BatchDataset(dataset, batchsize, perm=<function BatchDataset.>, merge=None, policy=‘include-last’, filter=<function BatchDataset.>)

参数解释

  • dataset(Dataset): 数据集,这个数据集每个元素必须是一个dict,(应该是用于multi-task,每个样本由多个类,这时候,存储为item = {‘data’:data , ‘class’:class}会比较方便
  • batchsize(int): batchsize。
  • perm (function, optional): 洗牌函数,数据随机打乱。
  • merge (function, optional): 控制产生batch的行为, 在transform.makebatch源码中使用这个函数. Default is None.它的作用是合并数据,直接从 Dataset 得到的一个 batch 是一个 dict 的列表,makebatch 默认行为是将 dict 中的 key 合并的数据合并,合并后使用 merge 进一步处理,默认行为(merge = None)就是将这个dict中的数据按照第一维度拼成一个Tensor(如果可以拼接的话)返回。应该是和torch.utils.data.dataloader中的collate_fn 如果需要pin_memory的话,需要把数据按锁页方式存储
  • policy (str, optional): 处理最后一个batch策略。
    • include-last 包含最后一个,无论剩几个
    • skip-last 最后一个小于 batchsize 的时候丢掉
    • divisible-only 数据不能整除batchsize时报错
  • filter (function, optional)
    • 在产生batch之前过滤, filter(sample) 返回 True则包含这个数据,False表示过滤掉,默认为True.

拼接datasets

torchnet.dataset.ConcatDataset(datasets)

参数

  • datasets(iterable)一个dataset列表。

产生List形式数据

torchnet.dataset.ListDataset(elem_list, load=<function ListDataset.>, path=None)

参数解释

  • elem_list(iterable/str):用于load数据的参数列表,(可以是文件名列表或数据本身等,根据load制定)
  • load (function, optional):一个load数据的函数,第i个样本由load(elem_list[i])得到。默认是是恒等映射 i.e, lambda x: x
  • path (str, optional) : Defaults to None. 表示数据的目录,如果这个被提供,则elem_list[i]在传给load时,会将这个作为前缀

返回一个采样数据集

torchnet.dataset.ResampleDataset(dataset, sampler=<function ResampleDataset.>, size=None)

  • dataset(Dataset)
  • sampler (function, optional):采样函数,返回的是下标,第idx样本由dataset[sampler(dataset, idx)]返回,默认是恒等映射。
  • size (int, optional): 目标数据的大小,默认与源来一样

均匀分布采样

torchnet.dataset.ShuffleDataset(dataset, size=None, replacement=False)

参数

  • dataset(Dataset):
  • size (int, optional): 目标数据大小,如果replacement为False且它大于原数据大小,则报错
  • replacement (bool, optional): 均匀分布放回抽样
    函数resample(seed=None): 对数据重新采样,默认不需要传一个随机的seed.

数据集分割

torchnet.dataset.SplitDataset(dataset, partitions, initial_partition=None)

参数解释

  • dataset(Dataset)
  • partitions (dict): 分割dict ,key 是分割的自定义名称,val是权重,(值在0和1之间)或者size大小指定每个部分的样本数。
  • initial_partition (str, optional): 初始化选择分割

函数

  • select(partition):partition是上面dict中key的一个,指定使用哪个部分
    partition(str)

方便把一个已经存在的内存数据变成标准的结构

torchnet.dataset.TensorDataset(data)

Dataset from a tensor or array or list or dict.

       data的形式

       tensor or numpy array

               idx`th sample is `data[idx]

       dict of tensors or numpy arrays

                  idx`th sample is `{k: v[idx] for k, v in data.items()}

        list of tensors or numpy arrays

                  idx`th sample is `[v[idx] for v in data]

得到一个变换数据集

torchnet.dataset.TransformDataset(dataset, transforms)

参数
  • dataset(Dataset)
  • transforms (function/dict): 一个函数(可以是compose的)或者dict(值是函数),用于样本的变换

torchnet.transform部分:

  • 主要用的是torchnet.transform.compose函数,将transform拼接在一起
  • 它接收一个transform列表,每个transform是一个函数,接收上一次的输出作为输入。和TransformDataset搭配使用。
  • 例如 TransformDataset(ListDataset(class_names), compose([transforms1,transform2]))

torchnet.engine

其将训练过程和测试过程进行包装,抽象成一个类,提供tran和test方法和一个hooks(这部分文档是问题的)
文档中描述的应该是,torch.tensor中的hook,原理一致,只不过tensor中的hook是在变量forward或者bachward的时候执行(两种hook).
hooks包括on_start, on_sample, on_forward, on_update, on_end_epoch, on_end,可以自己制定函数,在开始,load数据,forward,更新还有epoch结束以及训练结束时执行。一般是用开查看和保存模型训练过程的一些结果。

torch.logger

用于记录以下评估结果和可视化(用visdom)

torch.Meter

classtorchnet.meter.APMeter

计算每个类的平均准确率AP

计算每个类平均准确率AP

方法

  • add(output, target, weight=None)

  • output (Tensor) – NxK tensor表示N个样本,分别属于K个类的概率

  • target (Tensor) – binary NxK tensort 表示样本是否属于某个类 (eg: a row [0, 1, 0, 1]意味着样本属于classes 2 and 4)

  • weight (optional, Tensor) – Nx1 tensor样本权重(weight>0)
    reset()

  • value() 返回一个1xK FloatTensor 表示每个类的AP

classtorchnet.meter.mAPMeter

函数和参数跟上面一样,返回的是mAq
classtorchnet.meter.ClassErrorMeter(topk=[1], accuracy=False)

维护一个混淆矩阵conf,大小为 k ∗ k k*k kk

每行表示真实类别被和其他类的混淆值。
classtorchnet.meter.ConfusionMeter(k, normalized=False)

  • Parameters
    • K(int):类别数
    • normalized (boolean) – 混淆矩阵归一化(行归一化)
  • 方法
    • add(pedicted,target)
    • Parameters
      • predicted (tensor) – N x K tensor 或者一个 N-tensor (值0到k-1),为predictor的输出
      • target (tensor) – N x K tensor(one hot) 或者一个 N-tensor(值0到k-1) 为真实类别 value(): 返回混淆矩阵

回归损失meter

计算平均值

classtorchnet.meter.AverageValueMeter

方法:

  • add(self,value,n = 1)
  • Value是记录值,n是记录次数。
    reset()
    Value(): 返回平均值和标准差

计算AUC

classtorchnet.meter.AUCMeter

  • add(output, target):
  • reset():
  • value():返回的是 (area, tpr, fpr)

计算MovingAverage

classtorchnet.meter.MovingAverageValueMeter(windowsize)

  • windowsize:窗口大小
  • add(value): 记录value
  • reset()
  • value() : 返回MA和标准差

classtorchnet.meter.MSEMeter(root=False)

总结

慢慢的将各种函数模块都给其研究一波,慢慢的将其全部都搞定,研究透彻,全部都将其搞定。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

big_matster

您的鼓励,是给予我最大的动力!

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值