Pytorch-Lightning在模型训练时记录中间值报错RuntimeError:CUDA out of memory

在训练model的时候,Pytorch-Lightning通过定义System(以model作为输入)和Trainer的方式实现模型训练。

一、定义

(1)首先定义System:

res_from_recursion = []
system = System(
        model=model,#在system之前定义的
        loss_func=loss_func,
        optimizer=optimizer,
        train_loader=train_loader,
        res_from_recursion=res_from_recursion,#我想让模型在训练时输出的值,用一个list保存
        val_loader=val_loader,
        scheduler=scheduler,
        config=conf,
    )

(2)然后定义Trainer:

 trainer = pl.Trainer(
        max_epochs=conf["training"]["epochs"],
        callbacks=callbacks,
        default_root_dir=exp_dir,
        gpus=gpus,
        distributed_backend=distributed_backend,
        limit_train_batches=1.0,  # Useful for fast experiment
        gradient_clip_val=5.0,
        checkpoint_callback=checkpoint,
    )

(3)最后训练:

   trainer.fit(system)

二、修改System的内容

(4)在1.1中,System的输入参数被改变了,所以要在具体函数定义和初始化的地方也做修改。我这里是在asteroid/engine/system.py文件中

def __init__(
        self,
        model,
        optimizer,
        loss_func,
        train_loader,
        res_from_recursion=None, # !!
        val_loader=None,
        scheduler=None,
        config=None,
    ):
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.loss_func = loss_func
        self.train_loader = train_loader
        self.res_from_recursion = res_from_recursion # !!
        self.val_loader = val_loader
        self.scheduler = scheduler
        self.config = {} if config is None else config
        # Save lightning's AttributeDict under self.hparams
        self.save_hyperparameters(self.config_to_hparams(self.config))

(5)然后在模型训练的时候执行res_from_recursion.append(xxx),一直记录每次的中间值。但是在训练时报错RuntimeError:CUDA out of memory,但其实模型size很小(我记录的这个变量,每个input时只产生6个数)

我遇到这个问题的原因是,我在执行res_from_recursion.append(xxx)时,xxx还是tensor,没有转成float、int等。

因此,修改为:

cur_res_from_recursion.append(xxx.item())
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值