wandb在pytorch lightning中的使用


在wandb和pytorch lightning的官方文档中都有 在pytorch lightning中使用wandb的使用方法,链接分别为 https://docs.wandb.ai/guides/integrations/lightning#log-gradients-parameter-histogram-and-model-topologyhttps://pytorch-lightning.readthedocs.io/en/latest/extensions/generated/pytorch_lightning.loggers.WandbLogger.html?highlight=wandblogger。本笔记先以官方文档中内容为主进行使用方法的讲解,然后结合一个简单的使用ResNet网络图像分类的例子进行解析,部分没有提到的内容可自行在上述官方文档查询学习

使用前提

使用之前,如笔记wandb在pytorch中的使用记录中记录,需要先安装wandb,注册账号后使用密钥链接账号

使用解析

初始化

使用方面,pytorch lightning框架中提供了WandbLogger接口,在完成wandb安装后,直接使用WandbLogger接口提供的各种方法就能进行各类数据的记录,与单独使用wandb功能方面一致。使用时,只需先初始化一个WandbLogger类的对象,然后在定义trainer时将其作为logger传入即可,如下代码所示

from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer

wandb_logger = WandbLogger()  # 初始化个WandbLogger对象
trainer = Trainer(logger=wandb_logger)  # 初始化trainer

进行个WandbLogger类实例化时有以下常用参数

参数描述
project定义要登录的 wandb 项目
name对当前wandb运行记录设置名称
log_model设置记录参数范围:如果 log_model=“all” 记录所有模型,如果 log_model=True 在训练结束时记录
save_dir保存数据的路径

模型超参数保存

初始化时将WandbLogger的实例化对象设置给trainer后,继承pl.LightningModule进行网络定义时,在__int__函数中直接调用self.save_hyperparameters()即可进行超参数保存

class LitModule(LightningModule):
    def __init__(self, *args, **kwarg):
        self.save_hyperparameters()

记录其他配置参数

# add one parameter
wandb_logger.experiment.config["key"] = value

# add multiple parameters
wandb_logger.experiment.config.update({key1: val1, key2: val2})

# use directly wandb module
wandb.config["key"] = value
wandb.config.update()

记录梯度、参数直方图和模型拓扑

如wandb中一样,调用WandbLogger.watch()方法进行设置,如代码案例main()函数所示,该方法可以进行如下几种设置

# log gradients and model topology
wandb_logger.watch(model)

# log gradients, parameter histogram and model topology
wandb_logger.watch(model, log="all")

# change log frequency of gradients and parameters (100 steps by default)
wandb_logger.watch(model, log_freq=500)

# do not log graph (in case of errors)
wandb_logger.watch(model, log_graph=False)

记录metric

可以通过在 LightningModule 中调用 self.log(‘my_metric_name’, metric_vale) 将指标记录到 W&B,例如在 代码案例training_step()validation_step() 方法中

记录metric的最小值/最大值

使用 wandb 的 define_metric 函数,可以定义是否希望 wandb 汇总指标显示该指标的最小值、最大值、平均值或最佳值。如果未使用 define_metric,那么最后记录的值将出现在您的摘要指标中。有关更多信息,请参阅此处的 define_metric 参考文档和此处的指南

class My_LitModule(LightningModule):
    ...
    
    def validation_step(self, batch, batch_idx):
        if trainer.global_step == 0: 
            wandb.define_metric('val_accuracy', summary='max')  # 显示val_accuracy指标的最大值
        
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds

记录图像、文本等

WandbLogger 具有用于记录媒体的 log_image、log_text 和 log_table 方法;也可以直接调用 wandb.log 或 trainer.logger.experiment.log 来记录其他媒体类型,例如音频、分子、点云、3D 对象等
注意:在trainer中使用 wandb.log 或 trainer.logger.experiment.log 时,请确保在被传递的字典中也包含“global_step”:trainer.global_step。这样,可以将当前记录的信息与通过其他方法记录的信息对齐。

记录图像

# using tensors, numpy arrays or PIL images
wandb_logger.log_image(key="samples", images=[img1, img2])

# adding captions
wandb_logger.log_image(key="samples", images=[img1, img2], caption=["tree", "person"])

# using file path
wandb_logger.log_image(key="samples", images=["img_1.jpg", "img_2.jpg"])

# using .log in the trainer
trainer.logger.experiment.log({
    "samples": [wandb.Image(img, caption=caption) 
    for (img, caption) in my_images]
})

记录文本

# data should be a list of lists
columns = ["input", "label", "prediction"]
my_data = [["cheese", "english", "english"], ["fromage", "french", "spanish"]]

# using columns and data
wandb_logger.log_text(key="my_samples", columns=columns, data=my_data)

# using a pandas DataFrame
wandb_logger.log_text(key="my_samples", dataframe=my_dataframe)

记录表格数据

# log a W&B Table that has a text caption, an image and audio
columns = ["caption", "image", "sound"]

# data should be a list of lists
my_data = [["cheese", wandb.Image(img_1), wandb.Audio(snd_1)], 
        ["wine", wandb.Image(img_2), wandb.Audio(snd_2)]]

# log the Table
wandb_logger.log_table(key="my_samples", columns=columns, data=data)

也可以使用 Lightning 的回调系统来控制何时通过 WandbLogger 记录权重和偏差,在此示例中,我们记录了验证图像和预测的样本

import torch
import wandb
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger

class LogPredictionSamplesCallback(Callback):
    
    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        """Called when the validation batch ends;在每个验证batch结束时调用该函数"""
 
        # `outputs` comes from `LightningModule.validation_step`
        # which corresponds to our model predictions in this case
        # outputs来自validation_step的return数据,在本例中outputs对应于模型预测
        
        # Let's log 20 sample image predictions from the first batch,只记录来自第一个batch的前 20 个样本图像预测
        if batch_idx == 0:
            n = 20
            x, y = batch
            images = [img for img in x[:n]]
            captions = [f'Ground Truth: {y_i} - Prediction: {y_pred}' 
                for y_i, y_pred in zip(y[:n], outputs[:n])]
            
            
            # Option 1: log images with `WandbLogger.log_image`
            wandb_logger.log_image(
                key='sample_images', 
                images=images, 
                caption=captions)


            # Option 2: log images and predictions as a W&B Table
            columns = ['image', 'ground truth', 'prediction']
            data = [[wandb.Image(x_i), y_i, y_pred] for x_i, y_i, y_pred in list(zip(x[:n], y[:n], outputs[:n]))]
            wandb_logger.log_table(
                key='sample_table',
                columns=columns,
                data=data)            
...

trainer = pl.Trainer(
    ...
    callbacks=[LogPredictionSamplesCallback()]
)

在多GPU的情况下使用pytorch lightning和wandb

PyTorch Lightning 通过其 DDP 接口支持多 GPU。然而,PyTorch Lightning 的设计对于“如何实例化 多个GPU”要求仔细;Lightning 假设训练循环中的每个 GPU(或 Rank)必须以完全相同的方式实例化 - 具有相同的初始条件。但是,只有 rank 0 进程可以访问 wandb.run 对象,对于非零 rank 进程:wandb.run = None;可能会导致非零进程失败。这种情况会陷入僵局,因为rank 0 级进程将等待已经崩溃的非零级进程加入。出于这个原因,必须小心设置训练代码;推荐的设置方法是让代码独立于 wandb.run 对象。

class MNISTClassifier(pl.LightningModule):
    def __init__(self):
        super(MNISTClassifier, self).__init__()

        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        )
        
        self.loss = nn.CrossEntropyLoss()
    
    def forward(self, x):
        return self.model(x)
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        
        self.log("train/loss", loss)
        return {"train_loss": loss}
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.loss(y_hat, y)
        
        self.log("val/loss", loss)
        return {"val_loss": loss}
    
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)

def main():
    # Setting all the random seeds to the same value.
    # This is important in a distributed training setting. 
    # Each rank will get its own set of initial weights. 
    # If they don't match up, the gradients will not match either,
    # leading to training that may not converge.
    pl.seed_everything(1)  # 设置所有的随机种子为同一值,对于分布式训练这是重要的设置;每个进程都会有自己的一组初始权重,如果不匹配,那么梯度也将不匹配,从而导致训练可能不收敛

    train_loader = DataLoader(train_dataset,  batch_size = 64, 
                              shuffle = True, 
                              num_workers = 4)
    val_loader = DataLoader(val_dataset, 
                            batch_size = 64, 
                            shuffle = False, 
                            num_workers = 4)

    model = MNISTClassifier()
    wandb_logger = WandbLogger(project = "<project_name>")
    callbacks = [
        ModelCheckpoint(
            dirpath = "checkpoints",
            every_n_train_steps=100,
        ),
    ]
    trainer = pl.Trainer(
        max_epochs = 3, 
        gpus = 2, 
        logger = wandb_logger, 
        strategy="ddp", 
        callbacks=callbacks
    ) 
    trainer.fit(model, train_loader, val_loader)

代码案例

该代码是在一个简单的ResNet网络上修改的,其中注释进行了详细说明

import argparse
import os

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning import Trainer
from torchmetrics.functional import accuracy  # pytorch lightning提供的计算各种metrics的库
from d2l import torch as d2l

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # 使用GPU_0


# 残差块
class Residual(nn.Module):
    def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3,
                               padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3,
                               padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(input_channels, num_channels,
                                   kernel_size=1, stride=strides)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, X):
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        if self.conv3:
            X = self.conv3(X)
        Y += X
        return F.relu(Y)


# 构建resnet模块
def resnet_block(input_channels, num_channels, num_residuals, first_block=False):
    blk = []
    for i in range(num_residuals):
        if i == 0 and not first_block:
            blk.append(Residual(input_channels, num_channels, use_1x1conv=True, strides=2))
        else:
            blk.append(Residual(num_channels, num_channels))
    return blk


# 构建五层的ResNet模型
b1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64),
                   nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
b2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True))
b3 = nn.Sequential(*resnet_block(64, 128, 2))
b4 = nn.Sequential(*resnet_block(128, 256, 2))
b5 = nn.Sequential(*resnet_block(256, 512, 2))
net = nn.Sequential(b1, b2, b3, b4, b5, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, 10))


# 在定义trainer时,将定义的wandn_logger传入,则在定义模型时使用self.save_hyperparameters()和self.log()方法会自动将数据记录到wandb中
class PLResNet(pl.LightningModule):
    def __init__(self, args):
        super(PLResNet, self).__init__()
        # 如果不使用args传递超参数,可以在初始化时把对应的参数都赋给self,再调用self.save_hyperparameters()也能进行参数保存
        self.save_hyperparameters(args)  # 将超参数保存到self.hparams,也会自动记录到wandb中
        self.args = args

        self.net = net  # 以构建的resnet模型作为训练对象
        self.loss = nn.CrossEntropyLoss()  # 设置损失函数

    def configure_optimizers(self):  # 定义优化器
        return torch.optim.SGD(self.parameters(), lr=self.args.lr)

    def forward(self, X):  # 模型的前向计算过程
        return net(X)

    def training_step(self, batch, barch_idx):  # 训练step,即单个batch中的训练计算过程,需要返回损失
        _, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric,记录训练损失和metric,会将数据记录到wandb中
        self.log('train_loss', loss)
        self.log('train_accuracy', acc)
        return loss

    def validation_step(self, batch, barch_idx):  # 验证step
        preds, loss, acc = self._get_preds_loss_accuracy(batch)

        # Log loss and metric,记录验证损失和metric,会将数据记录到wandb中
        self.log('val_loss', loss)
        self.log('val_accuracy', acc)
        return preds

    def _get_preds_loss_accuracy(self, batch):  # 因为训练和验证的计算步骤相似,此函数计算每个batch中的损失和准确率
        X, y = batch
        y_hat = self(X)
        loss = self.loss(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        acc = accuracy(preds, y)
        return preds, loss, acc


def get_parser():
    parser = argparse.ArgumentParser()  # 简单地定义所需超参数
    parser.add_argument('--lr', type=float, default=0.03, help='learning rate')
    parser.add_argument("--batch_size", type=int, default=256, help="batch size")
    parser.add_argument("--epochs", default=5, type=int)
    return parser


def main():
    pl.seed_everything(1)

    parser = get_parser()
    args = parser.parse_args()
    model = PLResNet(args)  # 构建模型
    train_loader, val_loader = d2l.load_data_fashion_mnist(args.batch_size, resize=96)  # 加载数据

    wandb_logger = WandbLogger(project='resnet_test', name='pl1')  # 使用pytorch lightning的接口初始化wandb对象
    wandb_logger.watch(model, log='all', log_freq=10)  # 设置需要记录的数据范围
    # 定义trainer
    trainer = Trainer(max_epochs=args.epochs,
                      gpus=1,
                      logger=wandb_logger)  # 将初始化的wandb_logger设置为trainer的logger,这样在定义中self记录的数据会记录在wand中
    trainer.fit(model, train_loader, val_loader)  # 模型训练


if __name__ == '__main__':
    main()
  • 8
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值