torch简单网络必备操作

最近学习人工智能神经网络方面的东西,苦于CSDN上找的资料太碎了,因此整理一下一篇文章后续使用。

神经网络代码基本框架

导入数据,定义神经网络、训练、输出loss图、accuracy图,保存模型,读取模型,测试模型

1.导入数据

pytorch的数据读取
pytorch数据读取的核心是torch.utils.data.DataLoader类。即如果使用pytorch构建神经网络时,那么DNN中传入的数据类型就是DataLoader类。

在实际计算中,数据量很大,考虑到内存有限,且IO速度很慢,因此不能一次性的将其全部加载到内存中,也不能只用一个线程去加载。因而需要多线程、迭代加载, 因而专门定义加载器:DataLoader。
DataLoader 是一个可迭代对象, An Iterable Object, 内部配置了魔法函数——iter——,调用它将返回一个迭代器。

DataLoader(dataset, batch_size=1, shuffle=False, sampler=None,
           batch_sampler=None, num_workers=0, collate_fn=None,
           pin_memory=False, drop_last=False, timeout=0,
           worker_init_fn=None, *, prefetch_factor=2,
           persistent_workers=False)
# __init__参数包含两部分,前半部分用于指定数据集 + 采样器,后半部分为多线程参数。
  1. 数据集(dataset):从中加载数据的数据集。
  2. batch_size(int,可选):每个批次要加载多少个样本(默认值:1)。
  3. shuffle(bool,可选):设置为“True”在每个epoch以重新排列数据(默认值:False)。
  4. 采样器(采样器或可迭代,可选):定义要绘制的策略来自数据集的样本。可以是任何带有__len的``可迭代实施。如果指定,则不能指定参数shuffle。
  5. batch_sampler(sampler或Iterable,可选):like:attr:‘sampler’,但是一次返回一批索引。和一下参数相互排斥:attr:batch_size,:attr:shuffle,:attr:sampler,和:attr:‘drop_last’。
  6. num_workers(int,可选):用于数据的子进程数量装载0表示数据将在主进程中加载。(默认值:0
  7. collate_fn(可调用,可选):合并样本列表以形成小批量张量。在使用来自映射样式数据集。
  8. pin_memory(bool,可选):如果“True”,数据加载器将复制张量在返回它们之前,它被固定在CUDA的内存中。如果您的数据元素是自定义类型,或者您的:attr:collate_fn返回一个自定义类型的批处理,
  9. drop_last(bool,可选):设置为“True”以删除最后一个不完整的批,如果数据集大小不能被批大小整除。如果“False”,并且数据集的大小不能除以批大小,然后是最后一批将更小。(默认值:False
Dataset

pytorch支持两种类型的数据集map-style dataset和iterable-style dataset。
Map-style datasets:
字典型数据集是指实现了__getitem__()和__len__()协议,表示从索引到数据样本的映射。
可以继承抽象类torch.utils.data.DataSet,并重写__getitem__()和__len__()方法。
图片类型Dataset定义
Dataset是什么
Dataset在pytorch中是什么角色

import torch.utils.data as data
class Mydataset(data.Dataset):
    	
    def __init__(self):
        data=XXXX
        self.Data=torch.FLoatTensor(data)    
    def __getitem__(self, index):      
        return  self.data[index], self.target[index]
        ##返回你要提供给Dataloader的一个样本(数据+标签)

    def __len__(self):     
        return  len(self.data)
batch_size

提供每一笔数据时,数据的大小

shuffle (洗牌)

每个epoch中是否打乱样本顺序

sampler(取样器)

定义从数据集提取样本的策略。
对于IterableDataset来说,数据读取的顺利是由用户定义迭代决定的。回想下python的迭代器,只能通过循环调用next()方法,依次拿到下一个样本。不能改变原有的次序。

对may-style Dataset来说,sampler用来在数据读取时,指定样本索引的顺序。可以指定DataLoader的shuffle参数来指导顺序读取还是乱序读取。如果shuffle=True,会自动构造一个R

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值