Torch使两个Dataloader保持一致的打乱顺序

背景

PYG中,数据类无法自定义,有时候要汇聚多组图数据,此时要求Dateloader的打乱顺序一致使得他们输出相同。

方法

使用相同的随机数种子对数据进行打乱。

代码

from torch_geometric.loader import DataLoader
import random
seed=random.randint(0,99999999)
print(seed)
torch.manual_seed(seed)
train_data=train_data.shuffle()
torch.manual_seed(seed)
seq_train_data=seq_train_data.shuffle()
# 创建dataloader1对象
dataloader1 = DataLoader(train_data, batch_size=32, shuffle=False)
# 使用相同的随机种子创建dataloader2对象
dataloader2 = DataLoader(seq_train_data,batch_size=32, shuffle=False)
for data in zip(dataloader1,dataloader2):
    print(data[0].y)
    print(data[1].y)
    break

在这里插入图片描述

在使用`torch.utils.data.DataLoader`时,可以通过`dataset`参数来指定数据集。如果你想自定义数据集,可以按照以下步骤: 1. 首先,你需要创建一个新的类,继承自`torch.utils.data.Dataset`,并实现它的两个方法:`__len__`和`__getitem__`。其中,`__len__`方法应该返回数据集中样本的数量,`__getitem__`方法应该返回指定索引的样本。例如: ```python import torch from torch.utils.data import Dataset class MyDataset(Dataset): def __init__(self, data, targets): self.data = data self.targets = targets def __len__(self): return len(self.data) def __getitem__(self, index): x = self.data[index] y = self.targets[index] return x, y ``` 在这个例子中,我们自定义了一个数据集类`MyDataset`,它有两个属性`data`和`targets`,分别代表数据和标签。`__len__`方法返回数据集的长度,`__getitem__`方法返回指定索引的样本。 2. 创建数据集对象。在实际使用中,你需要将数据和标签传入`MyDataset`类中,创建一个数据集对象,例如: ```python data = torch.randn(10, 3, 32, 32) targets = torch.randint(0, 2, (10,)) dataset = MyDataset(data, targets) ``` 在这个例子中,我们使用`torch.randn`函数生成了一个形状为`(10, 3, 32, 32)`的张量`data`作为数据,使用`torch.randint`函数生成了一个形状为`(10,)`的张量`targets`作为标签,然后将它们传入`MyDataset`类中,创建了一个数据集对象`dataset`。 3. 创建`DataLoader`对象。最后,你可以创建一个`DataLoader`对象,将自定义的数据集作为参数传入,例如: ```python dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) ``` 在这个例子中,我们创建了一个`batch_size`为2,打乱顺序的`DataLoader`对象`dataloader`,并将自定义数据集`dataset`作为参数传入。 这样,你就可以使用自定义的数据集了。需要注意的是,如果你的自定义数据集非常大,可能需要考虑使用多进程来加速数据读取,可以设置`num_workers`参数来指定读取数据的进程数量。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值