# Build boolean node masks for one train/dev/test split and attach them to
# `data`. Each mask is reshaped to a single row of length labels.shape[0], so
# that a later run (presumably one per split) can append another row along
# dim 0 in the `else` branch below.
# NOTE(review): `labels`, `dev_index`, `test_index` and `data` come from the
# surrounding code; `data` is presumably a torch_geometric `Data` — confirm.
data_index = np.arange(labels.shape[0])
# np.isin replaces np.in1d, which is deprecated as of NumPy 2.0.
test_mask = torch.tensor(np.isin(data_index, test_index), dtype=torch.bool)
dev_mask = torch.tensor(np.isin(data_index, dev_index), dtype=torch.bool)
# Every node that is in neither the dev nor the test split is a training
# node; `|` makes the logical OR explicit instead of relying on `+` over
# bool tensors.
train_mask = ~(dev_mask | test_mask)
test_mask = test_mask.reshape(1, -1)
dev_mask = dev_mask.reshape(1, -1)
train_mask = train_mask.reshape(1, -1)
# Look the attribute up by name. The previous `train_mask not in data` tested
# the local tensor above rather than the key, and a plain `data.train_mask`
# can raise AttributeError when the mask is absent (e.g. recent
# torch_geometric `Data`), so use getattr with a None default.
if getattr(data, "train_mask", None) is None:
    # First split seen: store the (1, N) masks directly.
    data.train_mask = train_mask
    data.val_mask = dev_mask
    data.test_mask = test_mask
else:
    # Subsequent split: append this split's masks as one more row each.
    # NOTE(review): assumes the existing masks are already 2-D (k, N); a 1-D
    # mask shipped with the dataset would make torch.cat fail — confirm.
    data.train_mask = torch.cat((data.train_mask, train_mask), dim=0)
    data.val_mask = torch.cat((data.val_mask, dev_mask), dim=0)
    data.test_mask = torch.cat((data.test_mask, test_mask), dim=0)
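新版 PyG 中，访问不存在的属性（如 `data.train_mask`）会直接抛出 `AttributeError`，因此不能再用 `is None` 判断。
将 `if data.train_mask is None:` 改成 `if 'train_mask' not in data:` 即可解决。注意 `'train_mask'` 必须加引号：`in` 要按属性名（字符串）判断；若直接写 `train_mask`，检查的是同名的张量变量本身，而不是 `data` 中是否存在该键。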