9.pytorch lightning之数据模块LightningDataModule

使用LightningDataModule

LightningDataModule是一个可重用的类,内部封闭处理数据所需的步骤。

其中封装了pytorch处理数据的5个步骤:

  1. 准备数据:一般是数据集下载过程,如果是本地数据则不需要。
  2. 数据清理。
  3. 加载Dataset对象。
  4. 预处理:旋转,翻转,归一化等
  5. 将train_loader、val_loader、test_loader封装在一起。

使用LightningDataModule的好处是可以将数据处理整合在一起,并且很容易利用,创建后可以将其用于任何地方。

在原生pytorch代码中,针对train/val/test的不同阶段通常先创建多个dataset和dataloader,而在lightning 则将这些封装在一起,只需在fit()是传入单个ightningDataModule对象。

# regular PyTorch
test_data = MNIST(my_path, train=False, download=True)
predict_data = MNIST(my_path, train=False, download=True)
train_data = MNIST(my_path, train=True, download=True)
train_data, val_data = random_split(train_data, [55000, 5000])

train_loader = DataLoader(train_data, batch_size=32)
val_loader = DataLoader(val_data, batch_size=32)
test_loader = DataLoader(test_data, batch_size=32)
predict_loader = DataLoader(predict_data, batch_size=32)
# lightning 风格
class MNISTDataModule(pl.LightningDataModule):
    def __init__(self, data_dir: str = "./"):
        super().__init__()
        self.data_dir = data_dir
        self.transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

    # 下载数据
    def prepare_data(self):
        # download
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)
    '''
    setup 方法创建Dataset对象,对不同数据集指定预处理方法
    '''
    def setup(self, stage: str):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit":
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign test dataset for use in dataloader(s)
        if stage == "test":
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

        if stage == "predict":
            self.mnist_predict = MNIST(self.data_dir, train=False, transform=self.transform)

    # 以下方法创建不同阶段的数据加载器
    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=32)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=32)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=32)

    def predict_dataloader(self):
        return DataLoader(self.mnist_predict, batch_size=32)
    # 
    mnist = MNISTDataModule(my_path)
    model = LitClassifier()

    trainer = Trainer()
    # fit 训练阶段给定model 和data。
    trainer.fit(model, mnist)

另外,val_dataloader,train_dataloader等方法也可以实现在LightningModule子类中,fit时只给定相应的model实例。

同时加载多个数据集

这在半监督场景尤其有用,只需先将多个loader 用list,dict对象包装在一起,再封装在CombinedLoader中,注意CombinedLoader要求版本>2.0

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值