LightningCLI Usage Guide (Part 1)

Official documentation: https://lightning.ai/docs/overview/getting-started

Companion repository: https://github.com/zhangyiqian001/learndeep.git

ViT architecture: https://blog.csdn.net/weixin_43457608/article/details/147295658?sharetype=blogdetail&sharerId=147295658&sharerefer=PC&sharesource=weixin_43457608&spm=1011.2480.3001.8118

Launch command

from lightning.pytorch.cli import LightningCLI

# python .\main.py fit -c .\config.yaml
if __name__ == '__main__':
    cli = LightningCLI(
        # save the resolved config into the logger's log_dir and overwrite any existing copy
        save_config_kwargs={"save_to_log_dir": True, "overwrite": True}
    )
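
LightningCLI exposes the Trainer entry points as subcommands (fit, validate, test, predict), and the resolved configuration can be inspected or overridden from the command line. A few typical invocations (the checkpoint path is only a placeholder):

# print the fully-resolved configuration without running anything
python .\main.py fit -c .\config.yaml --print_config

# test with a previously saved checkpoint
python .\main.py test -c .\config.yaml --ckpt_path path\to\last.ckpt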

The BaseDataModule class

from lightning import LightningDataModule
from torch.utils.data import DataLoader


class BaseDataModule(LightningDataModule):

    def __init__(
            self,
            batch_size,
            num_workers
    ):
        super().__init__()

        self.train_set = None
        self.val_set = None
        self.test_set = None
        self.predict_set = None

        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transfer = None
        # keep DataLoader worker processes alive between epochs (only valid when num_workers > 0)
        self.persistent_workers = num_workers > 0
        # drop the last incomplete batch when the dataset size is not divisible by the batch size
        self.drop_last = True
        self.collate_fn = None

    def train_dataloader(self) -> DataLoader:
        return DataLoader(
            self.train_set,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=self.drop_last,
            collate_fn=self.collate_fn
        )

    def val_dataloader(self) -> DataLoader:
        return DataLoader(
            self.val_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=self.drop_last,
            collate_fn=self.collate_fn
        )

    def test_dataloader(self) -> DataLoader:
        return DataLoader(
            self.test_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=self.drop_last,
            collate_fn=self.collate_fn
        )
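
The predict_set attribute is left for subclasses as well, so a matching predict_dataloader can follow the same pattern; a minimal sketch:

    def predict_dataloader(self) -> DataLoader:
        return DataLoader(
            self.predict_set,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=self.num_workers,
            persistent_workers=self.persistent_workers,
            drop_last=False,  # keep every sample when predicting
            collate_fn=self.collate_fn
        )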

Custom dataset

import glob
import os

from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision import transforms

from datamodules.base_datamodule import BaseDataModule


class CatsDogsDataset(Dataset):
    def __init__(self, file_list, transform=None):
        self.file_list = file_list
        self.transform = transform

    def __len__(self):
        self.filelength = len(self.file_list)
        return self.filelength

    def __getitem__(self, idx):
        img_path = self.file_list[idx]
        img = Image.open(img_path).convert("RGB")  # make sure every image is 3-channel
        img_transformed = self.transform(img)

        # filenames look like "dog.123.jpg" / "cat.456.jpg"
        label = os.path.basename(img_path).split(".")[0]
        label = 1 if label == "dog" else 0

        return img_transformed, label


class CatsAndDogsDataModule(BaseDataModule):
    def __init__(self, kwargs):
        super().__init__(
            kwargs.pop('batch_size'),
            kwargs.pop('num_workers')
        )
        self.kwargs = kwargs

    def prepare_data(self) -> None:
        pass

    def setup(self, stage: str) -> None:
        train_list = glob.glob(os.path.join(self.kwargs['root'], 'train', '*.jpg'))
        test_list = glob.glob(os.path.join(self.kwargs['root'], 'test', '*.jpg'))
        labels = [os.path.basename(path).split('.')[0] for path in train_list]
        train_list, valid_list = train_test_split(train_list,
                                                  test_size=0.2,
                                                  stratify=labels,
                                                  random_state=42)
        train_transforms = transforms.Compose(
            [
                transforms.Resize((224, 224)),
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ]
        )

        val_transforms = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )

        test_transforms = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
            ]
        )
        self.train_set = CatsDogsDataset(train_list, transform=train_transforms)
        self.val_set = CatsDogsDataset(valid_list, transform=val_transforms)
        self.test_set = CatsDogsDataset(test_list, transform=test_transforms)
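
To sanity-check the datamodule outside the CLI, it can be driven by hand. A small sketch, assuming the Kaggle dogs-vs-cats images are unpacked under data/dogs-vs-cats/train and data/dogs-vs-cats/test:

# standalone smoke test, not part of the CLI flow
dm = CatsAndDogsDataModule({"root": "data/dogs-vs-cats/", "batch_size": 64, "num_workers": 2})
dm.setup("fit")
images, labels = next(iter(dm.train_dataloader()))
print(images.shape, labels[:8])  # expected: torch.Size([64, 3, 224, 224])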

The BaseModule class

from typing import Any

import torchvision
from lightning import LightningModule
from torch import Tensor


class BaseModule(LightningModule):

    def __init__(self):
        super().__init__()

    def forward(self, inputs) -> Tensor:
        # accept either a plain tensor or a dict that is expanded into keyword arguments
        if isinstance(inputs, dict):
            return self.model(**inputs)
        return self.model(inputs)


class BaseClassificationModule(BaseModule):
    # subclasses are expected to set self.model, self.loss and self.metrics

    def training_step(self, batch, batch_idx: int) -> Tensor:
        if self.loss is None:
            # the wrapped model computes its own loss from the full batch dict
            loss = self(batch)
            values = {"train_loss": loss}
        else:
            output = self(batch['inputs'])
            loss = self.loss(output, batch['targets'])
            acc = self.metrics(output.argmax(1), batch['targets'])
            values = {"train_loss": loss, "train_acc": acc}
        self.log_dict(values, prog_bar=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx: int) -> Tensor:
        if self.loss is None:
            loss = self(batch)
            values = {"val_loss": loss}
        else:
            output = self(batch['inputs'])
            loss = self.loss(output, batch['targets'])
            acc = self.metrics(output.argmax(1), batch['targets'])
            values = {"val_loss": loss, "val_acc": acc}
        try:
            # log a few sample inputs to TensorBoard (silently skipped for non-image data)
            sample_imgs = batch['inputs'][:5]
            grid = torchvision.utils.make_grid(sample_imgs)
            self.loggers[0].experiment.add_image('example_images', grid, 0)
        except Exception as e:
            print(e)
            print("dataset is not images")
        self.log_dict(values, prog_bar=True, sync_dist=True)
        return loss

    def test_step(self, batch: Any, batch_idx):
        inputs = batch['inputs']
        output = self(inputs)
        return output

Defining the ViT model

import torch
from torch import nn
from einops import repeat
from einops.layers.torch import Rearrange
# the Transformer encoder block is assumed to come from the vit-pytorch package
from vit_pytorch.vit import Transformer


def pair(t):
    # turn a single int into an (height, width) tuple
    return t if isinstance(t, tuple) else (t, t)


class ViTModel(nn.Module):
    def __init__(
            self,
            *,
            image_size=224,
            patch_size=32,
            num_classes=2,
            dim=128,
            depth=12,
            heads=8,
            mlp_dim=512,
            pool='cls',
            channels=3,
            dim_head=64,
            dropout=0.,
            emb_dropout=0.
    ):
        super().__init__()
        image_height, image_width = pair(image_size)
        patch_height, patch_width = pair(patch_size)

        assert image_height % patch_height == 0 and image_width % patch_width == 0, 'Image dimensions must be divisible by the patch size.'

        num_patches = (image_height // patch_height) * (image_width // patch_width)
        patch_dim = channels * patch_height * patch_width
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.to_patch_embedding = nn.Sequential(
            # b, patch, pixel
            # torch.Size([64, 49, 3072])
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.LayerNorm(patch_dim),
            nn.Linear(patch_dim, dim),
            nn.LayerNorm(dim),
        )

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(emb_dropout)

        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)

        self.pool = pool
        self.to_latent = nn.Identity()

        self.mlp_head = nn.Linear(dim, num_classes)

    def forward(self, img):
        # torch.Size([64, 3, 224, 224])
        x = self.to_patch_embedding(img)
        # torch.Size([64, 49, 128])
        b, n, _ = x.shape

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        # torch.Size([64, 50, 128])
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding[:, :(n + 1)]
        x = self.dropout(x)

        x = self.transformer(x)

        x = x.mean(dim=1) if self.pool == 'mean' else x[:, 0]

        x = self.to_latent(x)
        return self.mlp_head(x)


class ViTModule(BaseClassificationModule):
    def __init__(self, model, loss, metrics):
        super().__init__()
        self.model = model
        self.loss = loss
        self.metrics = metrics

    def on_before_batch_transfer(self, batch, dataloader_idx: int):
        return {
            "inputs": batch[0],
            "targets": batch[1],
        }

    def transfer_batch_to_device(self, batch, device: torch.device, dataloader_idx: int):
        result = {}
        for key, value in batch.items():
            if isinstance(value, dict):
                result[key] = {k: v.to(device) for k, v in value.items()}
            else:
                result[key] = value.to(device)
        return result
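
A quick shape check of the bare model, independent of the Lightning wiring (a sketch):

# with patch_size=32 on a 224x224 image there are 7*7=49 patches plus one cls token
model = ViTModel(image_size=224, patch_size=32, num_classes=2, dim=128, depth=12, heads=8, mlp_dim=512)
logits = model(torch.randn(2, 3, 224, 224))
print(logits.shape)  # expected: torch.Size([2, 2])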

Common configuration

model:
  class_path: modules.vit_module.ViTModule
  init_args:
    model:
      class_path: modules.vit_module.ViTModel
      init_args:
        image_size: 224
        patch_size: 32
        num_classes: &num_classes 2
        dim: 128
        depth: 12
        heads: 8
        mlp_dim: 512
        pool: 'cls'
        channels: 3
        dim_head: 64
        dropout: 0.
        emb_dropout: 0.
    loss:
      class_path: torch.nn.CrossEntropyLoss
    metrics:
      class_path: torchmetrics.Accuracy
      init_args:
        task: multiclass
        num_classes: *num_classes


data:
  class_path: datamodules.CatsAndDogsDataModule
  init_args:
    kwargs:
      root: data/dogs-vs-cats/
      batch_size: 64
      num_workers: 4
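
Since ViTModule above does not define configure_optimizers, LightningCLI can also wire an optimizer and learning-rate scheduler from the top-level optimizer and lr_scheduler keys shown in the full configuration below. A sketch of such a block (values are illustrative):

optimizer:
  class_path: torch.optim.AdamW
  init_args:
    lr: 3.0e-4
    weight_decay: 0.01

lr_scheduler:
  class_path: torch.optim.lr_scheduler.CosineAnnealingLR
  init_args:
    T_max: 10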

Training configuration file reference

# lightning.pytorch==2.5.1
# Runs the full optimization routine.

# Set to an int to run seed_everything with this value before classes instantiation. Set to True to use a random seed. (type: Union[bool, int], default: True)
seed_everything: true

# Customize every aspect of training via flags
trainer:
  # Supports passing different accelerator types ("cpu", "gpu", "tpu", "hpu", "mps", "auto")
  # as well as custom accelerator instances. (type: Union[str, Accelerator], default: auto, known subclasses: lightning.pytorch.accelerators.CPUAccelerator, lightning.pytorch.accelerators.CUDAAccelerator, lightning.pytorch.accelerators.MPSAccelerator, lightning.pytorch.accelerators.XLAAccelerator)
  accelerator: auto

  # Supports different training strategies with aliases as well as custom strategies.
  # Default: ``"auto"``. (type: Union[str, Strategy], default: auto, known subclasses: lightning.pytorch.strategies.DDPStrategy, lightning.pytorch.strategies.DeepSpeedStrategy,
  # lightning.pytorch.strategies.XLAStrategy, lightning.pytorch.strategies.FSDPStrategy, lightning.pytorch.strategies.ModelParallelStrategy, lightning.pytorch.strategies.SingleDeviceStrategy, lightning.pytorch.strategies.SingleDeviceXLAStrategy)
  strategy: auto

  # The devices to use. Can be set to a positive number (int or str), a sequence of device indices
  # (list or str), the value ``-1`` to indicate all available devices should be used, or ``"auto"`` for
  # automatic selection based on the chosen accelerator. Default: ``"auto"``. (type: Union[list[int], str, int], default: auto)
  devices: auto

  # Number of GPU nodes for distributed training.
  # Default: ``1``. (type: int, default: 1)
  num_nodes: 1

  # Double precision (64, '64' or '64-true'), full precision (32, '32' or '32-true'),
  # 16bit mixed precision (16, '16', '16-mixed') or bfloat16 mixed precision ('bf16', 'bf16-mixed').
  # Can be used on CPU, GPU, TPUs, or HPUs.
  # Default: ``'32-true'``. (type: Union[Literal[64, 32, 16], Literal['transformer-engine', 'transformer-engine-float16', '16-true', '16-mixed', 'bf16-true', 'bf16-mixed', '32-true', '64-true'], Literal['64', '32', '16', 'bf16'], null], default: null)
  precision:

  # Logger (or iterable collection of loggers) for experiment tracking. A ``True`` value uses
  # the default ``TensorBoardLogger`` if it is installed, otherwise ``CSVLogger``.
  # ``False`` will disable logging. If multiple loggers are provided, local files
  # (checkpoints, profiler traces, etc.) are saved in the ``log_dir`` of the first logger.
  # Default: ``True``. (type: Union[Logger, Iterable[Logger], bool, null], default: null, known subclasses: lightning.pytorch.loggers.logger.DummyLogger,
  # lightning.pytorch.loggers.CometLogger, lightning.pytorch.loggers.CSVLogger, lightning.pytorch.loggers.MLFlowLogger, lightning.pytorch.loggers.NeptuneLogger, lightning.pytorch.loggers.TensorBoardLogger, lightning.pytorch.loggers.WandbLogger)
  logger:

  # Add a callback or list of callbacks.
  # Default: ``None``. (type: Union[list[Callback], Callback, null], default: null, known subclasses: lightning.Callback, lightning.pytorch.callbacks.BatchSizeFinder,
  # lightning.pytorch.callbacks.Checkpoint, lightning.pytorch.callbacks.ModelCheckpoint, lightning.pytorch.callbacks.OnExceptionCheckpoint, lightning.pytorch.callbacks.DeviceStatsMonitor,
  # lightning.pytorch.callbacks.EarlyStopping, lightning.pytorch.callbacks.BaseFinetuning, lightning.pytorch.callbacks.BackboneFinetuning, lightning.pytorch.callbacks.GradientAccumulationScheduler,
  # lightning.pytorch.callbacks.LambdaCallback, lightning.pytorch.callbacks.LearningRateFinder, lightning.pytorch.callbacks.LearningRateMonitor, lightning.pytorch.callbacks.ModelSummary,
  # lightning.pytorch.callbacks.RichModelSummary, lightning.pytorch.callbacks.BasePredictionWriter, lightning.pytorch.callbacks.ProgressBar, lightning.pytorch.callbacks.RichProgressBar,
  # lightning.pytorch.callbacks.TQDMProgressBar, lightning.pytorch.callbacks.Timer, lightning.pytorch.callbacks.ModelPruning, lightning.pytorch.callbacks.SpikeDetection, lightning.pytorch.callbacks.StochasticWeightAveraging, lightning.pytorch.callbacks.ThroughputMonitor, lightning.pytorch.cli.SaveConfigCallback)
  callbacks:

  # Runs n if set to ``n`` (int) else 1 if set to ``True`` batch(es)
  # of train, val and test to find any bugs (ie: a sort of unit test).
  # Default: ``False``. (type: Union[int, bool], default: False)
  fast_dev_run: false

  # Stop training once this number of epochs is reached. Disabled by default (None).
  # If both max_epochs and max_steps are not specified, defaults to ``max_epochs = 1000``.
  # To enable infinite training, set ``max_epochs = -1``. (type: Optional[int], default: null)
  max_epochs:

  # Force training for at least these many epochs. Disabled by default (None). (type: Optional[int], default: null)
  min_epochs:

  # Stop training after this number of steps. Disabled by default (-1). If ``max_steps = -1``
  # and ``max_epochs = None``, will default to ``max_epochs = 1000``. To enable infinite training, set
  # ``max_epochs`` to ``-1``. (type: int, default: -1)
  max_steps: -1

  # Force training for at least these number of steps. Disabled by default (``None``). (type: Optional[int], default: null)
  min_steps:

  # Stop training after this amount of time has passed. Disabled by default (``None``).
  # The time duration can be specified in the format DD:HH:MM:SS (days, hours, minutes, seconds), as a
  # :class:`datetime.timedelta`, or a dictionary with keys that will be passed to
  # :class:`datetime.timedelta`. (type: Union[str, timedelta, dict[str, int], null], default: null)
  max_time:

  # How much of training dataset to check (float = fraction, int = num_batches).
  # Default: ``1.0``. (type: Union[int, float, null], default: null)
  limit_train_batches:

  # How much of validation dataset to check (float = fraction, int = num_batches).
  # Default: ``1.0``. (type: Union[int, float, null], default: null)
  limit_val_batches:

  # How much of test dataset to check (float = fraction, int = num_batches).
  # Default: ``1.0``. (type: Union[int, float, null], default: null)
  limit_test_batches:

  # How much of prediction dataset to check (float = fraction, int = num_batches).
  # Default: ``1.0``. (type: Union[int, float, null], default: null)
  limit_predict_batches:

  # Overfit a fraction of training/validation data (float) or a set number of batches (int).
  # Default: ``0.0``. (type: Union[int, float], default: 0.0)
  overfit_batches: 0.0

  # How often to check the validation set. Pass a ``float`` in the range [0.0, 1.0] to check
  # after a fraction of the training epoch. Pass an ``int`` to check after a fixed number of training
  # batches. An ``int`` value can only be higher than the number of training batches when
  # ``check_val_every_n_epoch=None``, which validates after every ``N`` training batches
  # across epochs or during iteration-based training.
  # Default: ``1.0``. (type: Union[int, float, null], default: null)
  val_check_interval:

  # Perform a validation loop after every `N` training epochs. If ``None``,
  # validation will be done solely based on the number of training batches, requiring ``val_check_interval``
  # to be an integer value.
  # Default: ``1``. (type: Optional[int], default: 1)
  check_val_every_n_epoch: 1

  # Sanity check runs n validation batches before starting the training routine.
  # Set it to `-1` to run all batches in all validation dataloaders.
  # Default: ``2``. (type: Optional[int], default: null)
  num_sanity_val_steps:

  # How often to log within steps.
  # Default: ``50``. (type: Optional[int], default: null)
  log_every_n_steps:

  # If ``True``, enable checkpointing.
  # It will configure a default ModelCheckpoint callback if there is no user-defined ModelCheckpoint in
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.callbacks`.
  # Default: ``True``. (type: Optional[bool], default: null)
  enable_checkpointing:

  # Whether to enable the progress bar by default.
  # Default: ``True``. (type: Optional[bool], default: null)
  enable_progress_bar:

  # Whether to enable model summarization by default.
  # Default: ``True``. (type: Optional[bool], default: null)
  enable_model_summary:

  # Accumulates gradients over k batches before stepping the optimizer.
  # Default: 1. (type: int, default: 1)
  accumulate_grad_batches: 1

  # The value at which to clip gradients. Passing ``gradient_clip_val=None`` disables
  # gradient clipping. If using Automatic Mixed Precision (AMP), the gradients will be unscaled before.
  # Default: ``None``. (type: Union[int, float, null], default: null)
  gradient_clip_val:

  # The gradient clipping algorithm to use. Pass ``gradient_clip_algorithm="value"``
  # to clip by value, and ``gradient_clip_algorithm="norm"`` to clip by norm. By default it will
  # be set to ``"norm"``. (type: Optional[str], default: null)
  gradient_clip_algorithm:

  # If ``True``, sets whether PyTorch operations must use deterministic algorithms.
  # Set to ``"warn"`` to use deterministic algorithms whenever possible, throwing warnings on operations
  # that don't support deterministic mode. If not set, defaults to ``False``. Default: ``None``. (type: Union[bool, Literal['warn'], null], default: null)
  deterministic:

  # The value (``True`` or ``False``) to set ``torch.backends.cudnn.benchmark`` to.
  # The value for ``torch.backends.cudnn.benchmark`` set in the current session will be used
  # (``False`` if not manually set). If :paramref:`~lightning.pytorch.trainer.trainer.Trainer.deterministic`
  # is set to ``True``, this will default to ``False``. Override to manually set a different value.
  # Default: ``None``. (type: Optional[bool], default: null)
  benchmark:

  # Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` during
  # evaluation (``validate``/``test``/``predict``). (type: bool, default: True)
  inference_mode: true

  # Whether to wrap the DataLoader's sampler with
  # :class:`torch.utils.data.DistributedSampler`. If not specified this is toggled automatically for
  # strategies that require it. By default, it will add ``shuffle=True`` for the train sampler and
  # ``shuffle=False`` for validation/test/predict samplers. If you want to disable this logic, you can pass
  # ``False`` and add your own distributed sampler in the dataloader hooks. If ``True`` and a distributed
  # sampler was already added, Lightning will not replace the existing one. For iterable-style datasets,
  # we don't do this automatically. (type: bool, default: True)
  use_distributed_sampler: true

  # To profile individual steps during training and assist in identifying bottlenecks.
  # Default: ``None``. (type: Union[Profiler, str, null], default: null, known subclasses: lightning.pytorch.profilers.AdvancedProfiler, lightning.pytorch.profilers.PassThroughProfiler, lightning.pytorch.profilers.PyTorchProfiler, lightning.pytorch.profilers.SimpleProfiler, lightning.pytorch.profilers.XLAProfiler)
  profiler:

  # Enable anomaly detection for the autograd engine.
  # Default: ``False``. (type: bool, default: False)
  detect_anomaly: false

  # Whether to run in "barebones mode", where all features that may impact raw speed are
  # disabled. This is meant for analyzing the Trainer overhead and is discouraged during regular training
  # runs. The following features are deactivated:
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_checkpointing`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.logger`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_progress_bar`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.log_every_n_steps`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.enable_model_summary`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.num_sanity_val_steps`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.fast_dev_run`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.detect_anomaly`,
  # :paramref:`~lightning.pytorch.trainer.trainer.Trainer.profiler`,
  # :meth:`~lightning.pytorch.core.LightningModule.log`,
  # :meth:`~lightning.pytorch.core.LightningModule.log_dict`. (type: bool, default: False)
  barebones: false

  # Plugins allow modification of core behavior like ddp and amp, and enable custom lightning plugins.
  # Default: ``None``. (type: Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync, list[Union[Precision, ClusterEnvironment, CheckpointIO, LayerSync]], null], default: null,
  # known subclasses: lightning.pytorch.plugins.Precision, lightning.pytorch.plugins.MixedPrecision, lightning.pytorch.plugins.BitsandbytesPrecision, lightning.pytorch.plugins.DeepSpeedPrecision,
  # lightning.pytorch.plugins.DoublePrecision, lightning.pytorch.plugins.FSDPPrecision, lightning.pytorch.plugins.HalfPrecision, lightning.pytorch.plugins.TransformerEnginePrecision,
  # lightning.pytorch.plugins.XLAPrecision, lightning.fabric.plugins.environments.KubeflowEnvironment, lightning.fabric.plugins.environments.LightningEnvironment, lightning.fabric.plugins.environments.LSFEnvironment,
  # lightning.fabric.plugins.environments.MPIEnvironment, lightning.fabric.plugins.environments.SLURMEnvironment, lightning.fabric.plugins.environments.TorchElasticEnvironment, lightning.fabric.plugins.environments.XLAEnvironment,
  # lightning.fabric.plugins.TorchCheckpointIO, lightning.fabric.plugins.XLACheckpointIO, lightning.pytorch.plugins.AsyncCheckpointIO, lightning.pytorch.plugins.TorchSyncBatchNorm)
  plugins:

  # Synchronize batch norm layers between process groups/whole world.
  # Default: ``False``. (type: bool, default: False)
  sync_batchnorm: false

  # Set to a positive integer to reload dataloaders every n epochs.
  # Default: ``0``. (type: int, default: 0)
  reload_dataloaders_every_n_epochs: 0

  # Default path for logs and weights when no logger/ckpt_callback passed.
  # Default: ``os.getcwd()``.
  # Can be remote file paths such as `s3://mybucket/path` or 'hdfs://path/' (type: Union[str, Path, null], default: null)
  default_root_dir:

  # The name of the model being uploaded to Model hub. (type: Optional[str], default: null)
  model_registry:

# One or more arguments specifying "class_path" and "init_args" for any subclass of LightningModule. (type: Optional[LightningModule], default: null, known subclasses: lightning.LightningModule)
model:

# One or more arguments specifying "class_path" and "init_args" for any subclass of LightningDataModule. (type: Optional[LightningDataModule], default: null, known subclasses: lightning.LightningDataModule)
data:

# One or more arguments specifying "class_path" and "init_args" for any subclass of Optimizer. (type: Optional[Optimizer], default: null, known subclasses: torch.optim.Optimizer,
# torch.optim.Adafactor, torch.optim.Adadelta, torch.optim.Adagrad, torch.optim.Adam, torch.optim.Adamax, torch.optim.AdamW, torch.optim.ASGD, torch.optim.LBFGS, torch.optim.NAdam, torch.optim.RAdam, torch.optim.RMSprop, torch.optim.Rprop, torch.optim.SGD, torch.optim.SparseAdam)
optimizer:

# One or more arguments specifying "class_path" and "init_args" for any subclass of {LRScheduler,ReduceLROnPlateau}. (type: Union[LRScheduler, ReduceLROnPlateau, null], default: null,
# known subclasses: torch.optim.lr_scheduler.LRScheduler, torch.optim.lr_scheduler.LambdaLR, torch.optim.lr_scheduler.MultiplicativeLR, torch.optim.lr_scheduler.StepLR, torch.optim.lr_scheduler.MultiStepLR,
# torch.optim.lr_scheduler.ConstantLR, torch.optim.lr_scheduler.LinearLR, torch.optim.lr_scheduler.ExponentialLR, torch.optim.lr_scheduler.SequentialLR, torch.optim.lr_scheduler.PolynomialLR,
# torch.optim.lr_scheduler.CosineAnnealingLR, torch.optim.lr_scheduler.ChainedScheduler, torch.optim.lr_scheduler.ReduceLROnPlateau, lightning.pytorch.cli.ReduceLROnPlateau, torch.optim.lr_scheduler.CyclicLR, torch.optim.lr_scheduler.CosineAnnealingWarmRestarts, torch.optim.lr_scheduler.OneCycleLR, torch.optim.swa_utils.SWALR)
lr_scheduler:

# Path/URL of the checkpoint from which training is resumed. Could also be one of the special
# keywords ``"last"``, ``"hpc"`` and ``"registry"``.
# Otherwise, if there is no checkpoint file at the path, an exception is raised.

#     - best: the best model checkpoint from the previous ``trainer.fit`` call will be loaded
#     - last: the last model checkpoint from the previous ``trainer.fit`` call will be loaded
#     - registry: the model will be downloaded from the Lightning Model Registry with following notations:

#         - ``'registry'``: uses the latest/default version of default model set
#           with ``Trainer(..., model_registry="my-model")``
#         - ``'registry:model-name'``: uses the latest/default version of this model `model-name`
#         - ``'registry:model-name:version:v2'``: uses the specific version 'v2' of the model `model-name`
#         - ``'registry:version:v2'``: uses the default model set
#           with ``Trainer(..., model_registry="my-model")`` and version 'v2' (type: Union[str, Path, null], default: null)
ckpt_path:
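
Any of these keys can also be overridden directly on the command line with dotted paths, so the YAML only needs to pin the values that differ from the defaults. For example (values are illustrative):

python .\main.py fit -c .\config.yaml --trainer.max_epochs 20 --trainer.precision 16-mixed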
