Tensorflow2数据集过大,GPU内存不够

前言:
在我们平时使用tensorflow训练模型时,有时候可能因为数据集太大(比如VOC数据集等等)导致GPU内存不够导致终止,可以自制一个数据生成器来解决此问题。

代码如下:

def train_generator(train_path,train_labels,batch):
    over=len(train_path)%batch
    while True:
        for i in range(0,len(train_path)-over,batch):
            train_data=read_img(train_path[i:i+batch])
            train_label=train_labels[i:i+batch]
            yield (np.array(train_data), np.array(train_label))

方法就是将数据集图片的路径保存到一个列表之中,然后使用while循环在训练时进行不断读取,这里over的作用是防止图片长度不是batch整数倍,导致label的数据长度不等于batch,我在训练时出现了这样的问题,这是我的猜测。然后yield与return的不同是,return是在函数执行到return就会退出函数,而yield则不会退出函数,所以使用yield
最后一句话也可以改成:

yield ({'input':np.array(train_data)}, {'output':np.array(train_label)})

'input’是你网络第一层的名字.。
'output’是你网络最后一层的名字。

接下来是使用代码:

history=model.fit(train_generator(train_data,train_label,batch=Yolo_param.Batch_size),
          batch_size=Yolo_param.Batch_size,
          epochs=10,
          steps_per_epoch=1024,
          validation_steps=32,
          callbacks=[callback],
          validation_data=train_generator(test_data,test_label,batch=Yolo_param.Batch_size))

steps_per_epoch这个参数是每个epoch的数据大小,如果不给进度就能难显示。

最后就是显存设置:

gpus = tf.config.list_physical_devices('GPU')
if gpus:
  try:
    tf.config.set_logical_device_configuration(
        gpus[0],
        [tf.config.LogicalDeviceConfiguration(memory_limit=4096)])
  except RuntimeError as e:
    print(e)

4096就是你限制显卡内存的大小,可以根据自己显卡实际情况来进行设置

  • 9
    点赞
  • 21
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
当训练很小的数据集时,出现GPU内存不足的问题可能有以下几个原因: 1. 数据集过大:虽然数据集本身很小,但可能在加载数据时进行了一些处理或者转换,导致数据集的大小变大。在加载数据时,可以尝试使用tf.data.Dataset进行懒加载,从而减少内存占用。 2. 模型过复杂:即使数据集很小,如果模型非常复杂,也会导致内存不足。可以尝试减小模型的大小,例如减少网络层数、减小每层的神经元数量,或者使用更轻量级的模型。 3. 运算过程中内存占用高:在训练过程中,如果使用了大量的中间变量或者计算图过于复杂,也会导致内存占用增加。可以尝试优化计算图,减少中间变量的使用,或者使用更高效的计算方式,如使用tf.function进行静态图编译。 4. TensorFlow版本问题:某些版本的TensorFlow内存的使用不够优化,建议升级到最新版本,或者考虑使用其他更轻量级的深度学习框架。 针对以上问题,可以尝试以下解决方案: - 在训练过程中使用批量训练,即每次只加载一小部分数据进行训练,可以使用tf.data.Dataset.batch()方法实现。 - 使用较低的数据类型,如使用tf.float16代替tf.float32来减少内存消耗。 - 考虑降低模型的复杂度,如减少网络层数或神经元数量。 - 对计算图进行优化,减少中间变量的使用,或者使用tf.function进行计算图静态编译。 - 升级到最新版本的TensorFlow,或者考虑使用其他更轻量级的深度学习框架。 以上是一些可能的解决方案,具体需要根据实际情况进行尝试和调整。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

月明Mo

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值