For an iterable DataLoader that streams data from a source dynamically, without worrying about blowing up memory, see: Building an iterable DataLoader in PyTorch (PyTorch Data study, part 3).
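For reference, a minimal sketch of that streaming idea, assuming a line-oriented text file at a hypothetical path data.txt whose lines hold whitespace-separated numbers; it subclasses torch.utils.data.IterableDataset so the file is read lazily, one line at a time:
import torch
from torch.utils.data import DataLoader, IterableDataset


class StreamingDataset(IterableDataset):
    """Streams samples from disk instead of loading everything into memory."""

    def __init__(self, path: str):
        self.path = path  # hypothetical path to a line-oriented data file

    def __iter__(self):
        # Yield one sample per line; memory use stays flat regardless of file size
        with open(self.path) as f:
            for line in f:
                yield torch.tensor([float(x) for x in line.split()])


# Note: shuffle must stay off for an IterableDataset (DataLoader rejects shuffle=True)
stream_loader = DataLoader(StreamingDataset('data.txt'), batch_size=3)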
The standard DataLoader wrapper:
Use it as follows:
from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=3)  # dataset is any torch Dataset instance
# usage
for data in dataloader:
    print(data)
Example code:
import numpy as np
import torch
import torch.utils.data as Data


def train_auto_encoder(my_data: np.ndarray):
    train_tensor = torch.tensor(my_data).float()
    batch_size = 3
    train_loader = Data.DataLoader(
        dataset=Data.TensorDataset(train_tensor),  # data wrapped in Data.TensorDataset() may have any dimensionality
        batch_size=batch_size,  # samples per batch
        shuffle=True,           # whether to shuffle the data (shuffling is usually better for training)
        num_workers=2,          # number of worker subprocesses for loading data
        drop_last=True,         # drop the last batch if it has fewer than batch_size samples, keeping batch shapes consistent
    )
    for train_data in train_loader:
        print(train_data)


def main():
    data = np.random.random(size=(20, 5))
    train_auto_encoder(data)  # train on the positive examples


if __name__ == '__main__':
    main()
Running this produces output like:
[tensor([[0.1168, 0.8438, 0.1179, 0.2627, 0.9513],
        [0.3485, 0.6598, 0.9423, 0.5677, 0.6109],
        [0.1255, 0.1181, 0.3377, 0.8533, 0.8340]])]
[tensor([[0.4643, 0.7502, 0.9944, 0.7076, 0.3739],
        [0.2627, 0.2119, 0.1119, 0.7022, 0.5895],
        [0.1801, 0.7596, 0.0995, 0.0758, 0.2582]])]
...
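With 20 samples, batch_size=3, and drop_last=True, the loop prints six batches and silently drops the last two samples. Note also that each batch is printed as a list holding a single tensor: Data.TensorDataset yields one tuple per sample, and the default collate function keeps that structure. A minimal sketch of unpacking the batch directly in the loop header (the shape shown assumes the 20x5 data above):
for (batch,) in train_loader:  # TensorDataset yields 1-tuples, so unpack in the loop header
    print(batch.shape)         # torch.Size([3, 5]) for every batch, thanks to drop_last=True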