tensorflow内存泄漏或模型只加载不运行

使用tf2模型进行推理的过程中,发现模型的内存占用在逐步增加,甚至会因为OOM被kill掉进程,有时候模型只加载不运行,搜索得到很多五花八门的答案,有些认为是tf2本身的问题,但在使用内存追踪的时候发现,是模型的动态图没有得到释放,而导致这个问题出现的原因,是数据的加载方式存在问题!!!

        mhc_a_batches = list(chunks(mhc_seqs_a, self.batch_size))
        mhc_b_batches = list(chunks(mhc_seqs_b, self.batch_size))
        pep_batches = list(chunks(pep_seqs, self.batch_size))
        assert len(mhc_a_batches) == len(mhc_b_batches)
        assert len(mhc_a_batches) == len(pep_batches)
        size = len(mhc_a_batches)
        
        # 开始预测
        preds = []
        for i in range(size):
            _preds = self.model([mhc_a_batches[i], mhc_b_batches[i], pep_batches[i]], training = False)
            preds.extend(_preds.numpy().tolist())
        return preds

如这段代码,直接使用了list作为模型的输入,尽管tf2也支持numpy的输入格式,但却存在隐患,会产生大量的空tensor!!!

将其改为这样的形式,问题得到解决:

 mhc_seqs_a = tf.convert_to_tensor(mhc_seqs_a, dtype=tf.float32)
        mhc_seqs_b = tf.convert_to_tensor(mhc_seqs_b, dtype=tf.float32)
        pep_seqs   = tf.convert_to_tensor(pep_seqs, dtype=tf.float32)
        
        assert len(mhc_seqs_a) == len(mhc_seqs_b)
        assert len(mhc_seqs_a) == len(pep_seqs)
        
        ds = tf.data.Dataset.from_tensor_slices((mhc_seqs_a, mhc_seqs_b, pep_seqs)).batch(self.batch_size).prefetch(1)
        
        preds = []
        for x, y, z in ds:
            _preds = self.model([x,y,z], training=False)
            preds.extend(_preds.numpy().tolist())
        return preds

现在可以愉快的进行模型推理了,而且速度比之前要快几倍不止,实测在GPU上提速近30倍,可想而知对于上亿级别的数据,节省的时间多么可观!

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值