1.model.train()和model.eval()
pytorch中的model.train()和model.eval()
model.train() #使用BatchNormalization()和Dropout(),此举会修改网络中的参数
model.eval() #不使用BatchNormalization()和Dropout(),即用于验证和测试阶段网络的固化
2.torch.nn.BCELoss()和torch.nn.CrossEntropyLoss()
Pytorch详解BCELoss和BCEWithLogitsLoss
Pytorch常用的交叉熵损失函数CrossEntropyLoss()详解
BCELoss() 适合于多标签二分类,需要手动对对预测结果进行了Sigmoid变换。
BCEWithLogitsLoss() 在BCELoss的基础上对预测结果进行了Sigmoid变换。
CrossEntropyLoss() 可直接用于单标签多分类问题,等于softmax-log-NLLLoss合并。
3.Dataset,DataLoader,distributed.DistributedSampler
torch.utils.data.Dataset
1.表示Dataset的抽象类
1.必须重载__len__和__getitem__,前者提供了数据集的大小,后者支持整数索引
torch.utils.data.distributed.DistributedSampler
多机多卡情况下分布式训练数据的读取要解决的问题:不同的卡读取到的数据应该是不同的。
sampler 确保dataloader只会load到整个数据集的一个特定子集。
DistributedSampler为每一个子进程划分出一部分数据集,以避免不同进程之间数据重复。
torch.utils.data.DataLoader
Pytorch技巧1:DataLoader的collate_fn参数
1.数据加载器。组合数据集和采样器,并在数据集上提供单线程或多线程迭代器(分别便于单核、或多核GPU调用)
2.参数: (dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0,
collate_fn=<function default_collate>, pin_memory=False, drop_last=False)
shuffle:设置为True时会在每个epoch重新打乱数据
sampler:定义从数据集中提取样本的策略。如果指定,则忽略shuffle参数
num_workers:用多少个子进程加载数据。 0表示数据将在主进程中加载
pin_memory:如果设置为True,那么data loader将会在返回它们之前,将tensors拷贝到CUDA中的固定内存
drop_last:如果数据集大小不能被batch_size整除,则设置为True后可删除最后一个不完整的batch.反之,则最后一个batch将更小。
collate_fn:如何取样本,可以定义自己的函数来准确地实现想要的功能,对dataset的每一个batch重新组合