maix_train本地训练出现“Failed to get convolution algorithm”的解决方法

使用sipeed提供的maix_train进行k210本地模型训练的时候,出现了以下问题

2022-01-29 21:31:49,805 - [ERROR]: failed: TrainFailReason.ERROR_INTERNAL, error occurred when train, error: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.

查找之前的log信息会发现以下错误

Function call stack:
train_function

同时,

2022-01-29 21:31:49.794489: E tensorflow/stream_executor/cuda/cuda_dnn.cc:328] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR

经过查找发现这个报错是由于显存不足引起,可以直接在文件中加入如下代码,让tensorflow自动进行显存分配

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = InteractiveSession(config=config)

于是寻找到出现报错的train_function
这个函数在Train类中的classify和detect部分都有各自的实现,找到__init__两个文件搜索train不难找到
train/classifier/__init__.pytrain/detector/__init__.py中的两个train()函数,在import语句后加入上述代码,再次进行训练,发现classify任务可以正常运行,但detect任务出现报错

2022-01-30 00:50:25,990 - [ERROR]: train error: The Session graph is empty. Add operations to the graph before calling run().
2022-01-30 00:50:25,991 - [ERROR]: failed: TrainFailReason.ERROR_INTERNAL, error occurred when train, error: The Session graph is empty. Add operations to the graph before calling run().

将之前在detector目录下__init__.py中添加的代码改为

from tensorflow.compat.v1 import ConfigProto
from tensorflow.compat.v1 import InteractiveSession
config = ConfigProto()
config.gpu_options.allow_growth = True
session = tf.compat.v1.Session(config=config)

也就是把InteractiveSession改成Session
再次运行detect任务,问题解决

推测出现bug的原因是新版本tensorflow对旧版本适配不完善
本例配置如下:
Nvidia RTX2060
cuda 10.1 + cudnn
tensorflow 2.3.0


补充

在StackOverFlow下面有老哥指出

Try to set:
os.environ[‘TF_FORCE_GPU_ALLOW_GROWTH’] = 'true’
solved my problem
my environment:
Cudnn 7.6.5
Tensorflow 2.4
Cuda Toolkit 10.1
RTX 2060

但是在笔者最早的尝试中未奏效,但是后续在Windows系统下复现时尝试使用之前的代码修复后仍无法使用,换用了上面这个代码,相关选段如下

import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

成功解决了问题。
推测这个bug与Linux下的bug成因不同,是由于Windows下对显存访问的管理而出现,而不是tensorflow的兼容性问题

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值