训练模型时出现上述问题。
原因:
显卡内存爆了,出现OOM问题
解决办法:
1.允许gpu自增长
1. 1针对tf1.X
在代码的import语句下,添加
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
1.2 针对tf2.X
在代码的import语句下,添加
configuration = tf.compat.v1.ConfigProto()
configuration.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=configuration)
2. 调整batch_size
# batch_size = 4096
batch_size = 256
比如我将batch_size从4096调整到256,就不会报错了