pytorch的一条条

1.model.train()和model.eval()

pytorch中的model.train()和model.eval()

model.train() #使用BatchNormalization()和Dropout(),此举会修改网络中的参数
model.eval() #不使用BatchNormalization()和Dropout(),即用于验证和测试阶段网络的固化
2.torch.nn.BCELoss()和torch.nn.CrossEntropyLoss()

Pytorch详解BCELoss和BCEWithLogitsLoss

Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解

BCELoss() 适合于多标签二分类,需要手动对对预测结果进行了Sigmoid变换。
BCEWithLogitsLoss() 在BCELoss的基础上对预测结果进行了Sigmoid变换。
CrossEntropyLoss() 可直接用于单标签多分类问题,等于softmax-log-NLLLoss合并。
3.Dataset,DataLoader,distributed.DistributedSampler

class torch.utils.data

torch.utils.data.Dataset

1.表示Dataset的抽象类
1.必须重载__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引

torch.utils.data.distributed.DistributedSampler

Pytorch多机多卡训练

多机多卡情况下分布式训练数据的读取要解决的问题:不同的卡读取到的数据应该是不同的。
sampler 确保dataloader只会load到整个数据集的一个特定子集。
DistributedSampler为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。

torch.utils.data.DataLoader

pytorch之dataloader深入剖析

Pytorch技巧1:DataLoader的collate_fn参数

1.数据加载器。组合数据集和采样器,并在数据集上提供单线程或多线程迭代器(分别便于单核、或多核GPU调用)
2.参数: (dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0,
		collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
		
shuffle:设置为True时会在每个epoch重新打乱数据
sampler:定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数
num_workers:用多少个子进程加载数据。 0表示数据将在主进程中加载
pin_memory:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存
drop_last:如果数据集大小不能被batch_size整除,则设置为True后可删除最后一个不完整的batch.反之,则最后一个batch将更小。

collate_fn:如何取样本,可以定义自己的函数来准确地实现想要的功能,对dataset的每一个batch重新组合
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值