首先参阅以下两篇 PyTorch Lightning 官方文档了解具体细节:
Arbitrary iterable support — PyTorch Lightning 2.0.7 documentation
combined_loader — PyTorch Lightning 2.0.7 documentation
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
# Report whether a GPU is visible (the Trainer further down is pinned to CPU regardless).
print(torch.cuda.is_available())

# Toy datasets: every sample is a 1-element tensor; ``a`` holds 0..3, ``b`` holds 4..8.
dataset_a = [torch.tensor([value]) for value in range(0, 4)]
dataset_b = [torch.tensor([value]) for value in range(4, 9)]

# Named dataloaders with batch size 2; ``b`` has an odd sample count, so its
# last batch is a partial one.
iterables = {
    name: DataLoader(data, batch_size=2)
    for name, data in (("a", dataset_a), ("b", dataset_b))
}
class LitModel(pl.LightningModule):
    """Minimal LightningModule that only prints the batches it receives.

    It is used to show how Lightning feeds a single dataloader versus a dict of
    dataloaders into the ``*_step`` hooks. It has no parameters of its own.
    """

    def __init__(self):
        super().__init__()

    def training_step(self, batch, batch_idx):
        print("train:", batch)
        # Random stand-in loss. It must require grad: under automatic
        # optimization Lightning calls ``loss.backward()`` on the returned
        # value, which fails on a tensor that has no grad_fn.
        loss = torch.randn(1, requires_grad=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # No ``dataloader_idx`` parameter: this demo only validates with a
        # single dataloader.
        print("valid:", batch)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        # ``dataloader_idx`` keeps a default so the same hook works both with
        # several test dataloaders and with a single one; without the default
        # Lightning raises a RuntimeError for the single-dataloader case.
        if dataloader_idx == 0:
            print("test1:", batch)
        if dataloader_idx == 1:
            print("test2:", batch)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        print("predict:", batch)

    def configure_optimizers(self):
        """Return no optimizer.

        ``trainer.fit`` refuses to start unless this hook is overridden. The
        model has no trainable parameters, so ``None`` is returned and
        Lightning runs the fit loop without an optimizer (it emits a warning).
        """
        return None
trainer = pl.Trainer(accelerator="cpu", devices=1)
model = LitModel()

# Each pair is (Trainer entry point, dataloaders to feed it), run in order:
#   dict of loaders   -> test_step sees dataloader_idx 0 and 1
#   single loader "b" -> test_step sees only the default dataloader_idx 0
#   single loader "b" -> validation
#   dict of loaders   -> predict_step sees dataloader_idx 0 and 1
for run, loaders in (
    (trainer.test, iterables),
    (trainer.test, iterables["b"]),
    (trainer.validate, iterables["b"]),
    (trainer.predict, iterables),
):
    run(model, dataloaders=loaders)
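运行结果(简化版,只保留各 step 中 print 的内容及其所属的 DataLoader)。注意第二次 trainer.test 只传入了单个 DataLoader(iterables["b"]),此时 dataloader_idx 取默认值 0,所以 b 的数据以 "test1:" 打印: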
True
Testing DataLoader 0:
test1: tensor([[0],[1]])
Testing DataLoader 0:
test1: tensor([[2],[3]])
Testing DataLoader 1:
test2: tensor([[4],[5]])
Testing DataLoader 1:
test2: tensor([[6],[7]])
Testing DataLoader 1:
test2: tensor([[8]])
Testing DataLoader 0:
test1: tensor([[4],[5]])
Testing DataLoader 0:
test1: tensor([[6],[7]])
Testing DataLoader 0:
test1: tensor([[8]])
Validation DataLoader 0:
valid: tensor([[4],[5]])
Validation DataLoader 0:
valid: tensor([[6],[7]])
Validation DataLoader 0:
valid: tensor([[8]])
Predicting DataLoader 0:
predict: tensor([[0],[1]])
Predicting DataLoader 0:
predict: tensor([[2],[3]])
Predicting DataLoader 1:
predict: tensor([[4],[5]])
Predicting DataLoader 1:
predict: tensor([[6],[7]])
Predicting DataLoader 1:
predict: tensor([[8]])
注意当你把test_step改成下述代码时(即取消dataloader_idx的默认值):
def test_step(self, batch, batch_idx,dataloader_idx):
if dataloader_idx == 0:
print("test1:", batch)
if dataloader_idx == 1:
print("test2:", batch)
# Assumes LitModel.test_step was replaced by the no-default variant shown above.
trainer.test(model, dataloaders=iterables) # Dict of loaders: dataloader_idx is passed explicitly, so this runs and prints
trainer.test(model, dataloaders=iterables["b"]) # Single loader: execution fails here with the error quoted below
"""
错误信息为:
RuntimeError: You provided only a single `test_dataloader`,
but have included `dataloader_idx` in `LitModel.test_step()`.
Either remove the argument or give it a default value i.e. `dataloader_idx=0`.
"""
假设调用 trainer.fit(model, train_dataloaders=iterables),那么 training_step 中得到的 batch 是如下形式的字典(fit 阶段默认以 "max_size_cycle" 模式组合多个 DataLoader,较短的 a 会被循环使用,以配合较长的 b):
{'a': batch_from_loader_dataset_a, 'b': batch_from_loader_dataset_b}