Pytorch模型测试时显存一直上升导致爆显存

问题描述

首先说明: 由于我的测试集很大, 因此需要对测试集进行分批次推理.

在写代码的时候发现进行训练的时候大概显存只占用了2GB左右, 而且训练过程中显存占用量也基本上是不变的. 而在测试的时候, 发现显存在每个batch数据推理后逐渐增加, 直至最后导致爆显存, 程序fail.

这里放一下我测试的代码:

	y, y_ = torch.Tensor(), torch.Tensor()
        for batch in tqdm(loader):
            x, batch_y = batch[0], batch[1]
            batch_y_ = model(x)
            y = torch.cat([y, batch_y], dim=0)
            y_ = torch.cat([y_, batch_y_], dim=0)

解决方法

遇到问题后我就进行单步调试, 然后观察显存的变化. 发现在模型推理这一步, 每一轮次显存都会增加.

batch_y_ = model(x)

这里令人费解的是, 模型推理实际上在训练和测试中都是存在的, 为什么训练的时候就不会出现这个问题呢.

最后发现其实是在训练的时候有这样一步与测试不同:

self.optimizer.zero_grad()

⭐️ 在训练时, 每一个batch后都会将模型的梯度进行一次清零. 而测试的时候我则没有加这一步, 这样的话每次模型再做推理的时候都会产生新的梯度, 并累积到显存当中.

清楚了问题, 那么解决方法也就随之而来, 在测试的时候让模型不要记录梯度就好, 因为其实也用不到:

with torch.no_grad():
	test()

总结

  • 训练的时候每一个batch结束除了梯度反向传播, 还要提前清理梯度
  • 梯度如果不清理的话, 会在显存中累积下来
### PyTorch 训练在最后一个批次卡住的原因分析 当使用 PyTorch 进行训练,如果发现程序在处理最后一个批次卡住,可能是由多种原因引起的。以下是几个常见的可能性及其解决方案: #### 1. 数据加载器配置不当 数据加载器(`DataLoader`)可能导致问题的一个常见原因是其 `drop_last` 参数未正确设置。默认情况下,`DataLoader` 不会丢弃任何剩余的数据片段,即使这些片段不足以构成一个完整的批次。这可能会导致某些操作无法正常执行,尤其是在自定义模型或损失函数中。 可以通过调整 `DataLoader` 的参数来解决问题: ```python train_loader = torch.utils.data.DataLoader( dataset=train_dataset, batch_size=32, shuffle=True, drop_last=True # 如果不需要最后不足一批次的数据,则可以启用此选项 ) ``` #### 2. 自定义损失函数中的潜在错误 有,自定义的损失函数可能未能适当地处理不完整批次的情况。例如,假设某个损失函数期望输入张量具有固定大小,但在最后一个批次中由于数据数量较少而导致形状变化,从而引发异常或死锁现象[^4]。 为了验证这一点,可以在每次迭代前打印当前批次的尺寸以确认一致性: ```python for i, (inputs, labels) in enumerate(train_loader): print(f"Iteration {i}, Input shape: {inputs.shape}") outputs = model(inputs) loss = custom_loss_function(outputs, labels) loss.backward() optimizer.step() ``` #### 3. GPU 内存管理问题 另一个可能的因素涉及图形处理器(GPU)内存分配策略。如果网络非常大或者批量特别庞大,在接近结束阶段可能出现资源耗尽状况,进而阻止进一步计算完成。因此建议监控GPU利用率并适当减少批处理规模或其他消耗大量显存的部分。 此外还可以尝试通过调用 `.cpu()` 方法将变量移回CPU来进行调试诊断;不过需要注意这样做会影响性能表现。 ```python if inputs.is_cuda: inputs = inputs.cpu() # 将数据移到 CPU 上测试是否存在其他隐含问题 ``` #### 4. 并行化多线程读取冲突 利用多进程方式加速数据预取虽然有效提升效率但也增加了复杂度。如果有同步障碍存在于不同子进程中就容易造成阻塞等待情形发生。对此情况下的解决办法之一就是降低num_workers数目甚至设为0暂关闭异步机制观察效果如何改变。 ```python train_loader = torch.utils.data.DataLoader( ... num_workers=0, # 关闭多线程模式用于排查是否与此有关联 pin_memory=False ) ``` --- ### 总结 综上所述,PyTorch 在最后一轮训练期间停滞的现象可以从上述四个方面逐一排查定位根本诱因所在,并采取相应措施予以修复。值得注意的是,选择合适的优化器同样重要,因为它直接影响到整个学习过程能否顺利开展下去而不受阻碍 [^2].
评论 9
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值