功能初体验
import torch
import torch.utils.data as Data
if __name__ == '__main__':
torch.manual_seed(1) # reproducible
BATCH_SIZE = 5 # 批训练的数据个数
x = torch.linspace(11, 20, 10) # x data: tensor([11., 12., 13., 14., 15., 16., 17., 18., 19., 20.])
y = torch.linspace(1, 10, 10) # y data: tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
# print(x)
# print(y)
# 先转换成 torch 能识别的 Dataset
torch_dataset = Data.TensorDataset(x, y) #[return:
# (tensor(x1),tensor(y1));
# (tensor(x2),tensor(y2));
# ......
# print(torch_dataset, list(torch_dataset))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=False, # 要不要打乱数据 (打乱比较好)
num_workers=0, # 多线程来读数据
)
for epoch in range(3): # 训练所有!整套!数据 3 次
for step,(batch_x,batch_y) in enumerate(loader): # 每一步 loader 释放一小批数据用来学习
#return:
#(tensor(x1,x2,x3,x4,x5),tensor(y1,y2,y3,y4,y5))
#(tensor(x6,x7,x8,x9,x10),tensor(y6,y7,y8,y9,y10))
# 假设这里就是你训练的地方...
# 打出来一些数据
print('Epoch: ', epoch, '| Step:', step, '| batch x: ', batch_x.numpy(), '| batch y: ', batch_y.numpy())
参数简介
上图为源码中dataloader中所有的可选参数。除了第一个dataset参数外,其他均为可选参数。
- Dataset:处理好的所有数据
- batch_size:批数量
- shuffle:打乱数据
- sampler:采样机制,即从数据集里面取样本的方式(迭代器,每次返回一个样本)
- batch_sampler:把sampler的采样的样本根据batch_size组织成一个batch返回
- num_worker:加载数据的线程数
- collate_fn:把batch_sampler返回的list结构的一个batch的样本打包成一个tensor的结构
- pin_memory:将加载的数据拷贝到CUDA中的固定内存中,从而使数据更快地传输到支持cuda的gpu
- drop_last:丢弃余数
- timeout:如果是正数,表明等待从加载一个batch等待的时间,若超出设定的时间还没有加载完,就放弃这个batch,如果是0,表示不设置限制时间。默认为0
- worker_init_fn:如果不是None ,它将在每个worker子进程上以worker id ([0, num_workers - 1] )作为输入调用,在seeding之后和数据加载之前。
- generater:如果不是None,这个RNG将被RandomSampler用来生成随机索引,并被multiprocessing用来为worker生成’ base_seed ‘。 (默认值:’ ‘没有’ ')
- prefetch_factor:提前加载多少个batch的数据,可以保证线程不会等待,每个线程都总有至少一个数据在加载。提升显卡利用率。
- persistent_workers:如果为True,数据加载器将不会在数据集运行完一个Epoch后关闭worker进程。这允许维护worker数据集实例保持激活。(默认值:False),意思是运行完一个Epoch后并不会关闭worker进程,而是保持现有的worker进程继续进行下一个Epoch的数据加载。好处是Epoch之间不必重复关闭启动worker进程,加快训练速度。
Dataloader参数之间的互斥
值得注意的是,Dataloader的参数之间存在互斥的情况,主要针对自己定义的采样器:
- sampler:如果自行指定了sampler参数,则shuffle必须保持默认值,即False
- batch_sampler:如果自行指定了batch_sampler参数,则 batch_size、shuffle、sampler、drop_last 都必须保持默认值
- 如果没有指定自己是采样器,那么默认的情况下(即sampler和batch_sampler均为None的情况下),dataloader的采样策略是如何的呢:
sampler:
- shuffle = True:sampler采用 RandomSampler,即随机采样
- shuffle = Flase:sampler采用 SequentialSampler,即按照顺序采样
- batch_sampler:采用 BatchSampler,即根据 batch_size 进行batch采样
- 上面提到的 RandomSampler、SequentialSampler和BatchSampler都是PyTorch自己实现的,且它们都是Sampler的子类。
Sampler
SequentialSampler
SequentialSampler就是一个按照顺序进行采样的采样器,接收一个数据集做参数(实际上任何可迭代对象都可),按照顺序对其进行采样:
from torch.utils.data import SequentialSampler
pseudo_dataset = list(range(10, 20))
for data in SequentialSampler(pseudo_dataset):
print(data, end=" ")
0 1 2 3 4 5 6 7 8 9
RandomSampler
RandomSampler 即一个随机采样器,返回随机采样的值,第一个参数依然是一个数据集(或可迭代对象)。还有一组参数如下:
- replacement:bool值,默认是False,设置为True时表示可以采出重复的样本
- num_samples:只有在replacement设置为True的时候才能设置此参数,表示要采出样本的个数,默认为数据集的总长度。有时候由于replacement置True的原因导致重复数据被采样,导致有些数据被采不到,所以往往会设置一个比较大的值
from torch.utils.data import RandomSampler
pseudo_dataset = list(range(10, 20))
randomSampler1 = RandomSampler(pseudo_dataset)
randomSampler2 = RandomSampler(pseudo_dataset, replacement=True, num_samples=20)
print("for random sampler #1: ")
for data in randomSampler1:
print(data, end=" ")
print("\n\nfor random sampler #2: ")
for data in randomSampler2:
print(data, end=" ")
for random sampler #1:
4 5 2 9 3 0 6 8 7 1
for random sampler #2:
4 9 0 6 9 3 1 6 1 8 5 0 2 7 2 8 6 4 0 6
WeightedRandomSampler
WeightedRandomSampler和RandomSampler的参数一致,但是不在传入一个dataset,第一个参数变成了weights,只接收一个一定长度的list作为 weights 参数,表示采样的权重,采样时会根据权重随机从 list(range(len(weights))) 中采样,即WeightedRandomSampler并不需要传入样本集,而是只在一个根据weights长度创建的数组中采样,所以采样的结果可能需要进一步处理才能使用。weights的所有元素之和不需要为1。
from torch.utils.data import WeightedRandomSampler
weights = [1,1,10,10]
weightedRandomSampler = WeightedRandomSampler(weights, replacement=True, num_samples=20)
for data in weightedRandomSampler:
print(data, end=" ")
2 2 2 3 2 2 3 2 3 3 1 3 2 2 1 3 3 2 3 3
详细使用可参考: WeightedRandomSampler使用案例
BatchSampler
其他Sampler在每次迭代都只返回一个索引,而BatchSampler的作用是对上述这类返回一个索引的采样器进行包装,按照设定的batch size返回一组具体数据,因其他的参数和上述的有些不同:
- sampler:一个Sampler对象(或者一个可迭代对象)
- batch_size:batch的大小
- drop_last:是否丢弃最后一个可能不足batch size大小的数据
from torch.utils.data import BatchSampler
pseudo_dataset = list(range(10, 20))
batchSampler1 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=False)
batchSampler2 = BatchSampler(pseudo_dataset, batch_size=3, drop_last=True)
print("for batch sampler #1: ")
for data in batchSampler1:
print(data, end=" ")
print("\n\nfor batch sampler #2: ")
for data in batchSampler2:
print(data, end=" ")
for batch sampler #1:
[10, 11, 12] [13, 14, 15] [16, 17, 18] [19]
for batch sampler #2:
[10, 11, 12] [13, 14, 15] [16, 17, 18]
SubsetRandomSampler
SubsetRandomSampler 可以设置子集的随机采样,多用于将数据集分成多个集合,比如训练集和验证集的时候使用:
from torch.utils.data import SubsetRandomSampler
pseudo_dataset = list(range(10, 20))
subRandomSampler1 = SubsetRandomSampler(pseudo_dataset[:7])
subRandomSampler2 = SubsetRandomSampler(pseudo_dataset[7:])
print("for subset random sampler #1: ")
for data in subRandomSampler1:
print(data, end=" ")
print("\n\nfor subset random sampler #2: ")
for data in subRandomSampler2:
print(data, end=" ")
for subset random sampler #1:
14 15 11 16 13 10 12
for subset random sampler #2:
17 19 18
参考:https://blog.csdn.net/qq_38962621/article/details/111146427