调用bert后显存叠加的情况

如果不想训练bert,就要保证通过bert生成(包括二三次生成)的变量不能调用两次,否则可能计算梯度改变网络

例如

_, pool_out = self.bert(ids, output_all_encoded_layers=True)
comment_emb = self.linear_layer1(pool_out)
# a = comment_emb.detach()  # 解决方法,detach()共享数据内存,但不进入计算图
user_embedding = self.entity_regularizer(
    torch.index_select(self.entity_embedding, dim=0, index=queries[0]))
embedding = self.linear_layer3(torch.cat([comment_emb, user_embedding], dim=-1))
embedding = self.linear_layer3(torch.cat([comment_emb, embedding], dim=-1))
# embedding = self.linear_layer3(torch.cat([a, embedding], dim=-1))

上面的comment_emb被调用两次,最后一次调用会占30MB内存,每次训练会叠加,只调用一次不会叠加

解决方法
换个变量名或者直接detach(),这里不能用clone()

最好的办法还是训练模型前先生成bert模型的向量,再进行输入

如果是其他显存爆炸的问题,这里还有一个打印显存的办法,找显存爆炸的地方再看看怎么改

import os
import psutil


def get_gpu_mem_info(gpu_id=0):
    """
    根据显卡 id 获取显存使用信息, 单位 MB
    :param gpu_id: 显卡 ID
    :return: total 所有的显存,used 当前使用的显存, free 可使用的显存
    """
    import pynvml
    pynvml.nvmlInit()
    if gpu_id < 0 or gpu_id >= pynvml.nvmlDeviceGetCount():
        print(r'gpu_id {} 对应的显卡不存在!'.format(gpu_id))
        return 0, 0, 0

    handler = pynvml.nvmlDeviceGetHandleByIndex(gpu_id)
    meminfo = pynvml.nvmlDeviceGetMemoryInfo(handler)
    total = round(meminfo.total / 1024 / 1024, 2)
    used = round(meminfo.used / 1024 / 1024, 2)
    free = round(meminfo.free / 1024 / 1024, 2)
    return total, used, free


def get_cpu_mem_info():
    """
    获取当前机器的内存信息, 单位 MB
    :return: mem_total 当前机器所有的内存 mem_free 当前机器可用的内存 mem_process_used 当前进程使用的内存
    """
    mem_total = round(psutil.virtual_memory().total / 1024 / 1024, 2)
    mem_free = round(psutil.virtual_memory().available / 1024 / 1024, 2)
    mem_process_used = round(psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024, 2)
    return mem_total, mem_free, mem_process_used


if __name__ == "__main__":

    gpu_mem_total, gpu_mem_used, gpu_mem_free = get_gpu_mem_info(gpu_id=0)
    print(r'当前显卡显存使用情况:总共 {} MB, 已经使用 {} MB, 剩余 {} MB'
          .format(gpu_mem_total, gpu_mem_used, gpu_mem_free))

    cpu_mem_total, cpu_mem_free, cpu_mem_process_used = get_cpu_mem_info()
    print(r'当前机器内存使用情况:总共 {} MB, 剩余 {} MB, 当前进程使用的内存 {} MB'
          .format(cpu_mem_total, cpu_mem_free, cpu_mem_process_used))
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值