Mask Rcnn 源码复现和整个过程中遇到的问题总结
文章目录
我的运行环境为 python3.7、 tensorflow1.15.0、keras2.3.1
scikit-image0.19.3
1、下载源码和数据集
1.1 Mask Rcnn源码下载
Mask Rcnn官方源码和数据集下载`
需要下载Source code (zip)源码压缩包或者Source code (tar.gz)和balloon_dataset.zip:气球数据集、mask_rcnn_balloon.h5:训练好的气球数据集权重。还需要下载mask_rcnn_coco.h5,这个是基于mask rcnn训练好的coco数据集的权重,页面中的Mask Rcnn2.1运行时可以自动下载,但是速度较慢,可以向下翻到Mask Rcnn2.0中手动下载
目录结构如下,需手动创建images和logs文件夹
1.2 Mask Rcnn ballon.py文件修改
把–dataset、–weights、–logs改为自己当前数据的绝对路径、以及logs文件的绝对路径。
if __name__ == '__main__':
import argparse
# Parse command line arguments
parser = argparse.ArgumentParser(
description='Train Mask R-CNN to detect balloons.')
parser.add_argument("command",
metavar="<command>",
help="'train' or 'splash'")
parser.add_argument('--dataset', required=False,
metavar="/root/autodl-tmp/Mask_RCNN-2.1/balloon",
help='Directory of the Balloon dataset')
parser.add_argument('--weights', required=True,
metavar="/root/autodl-tmp/Mask_RCNN-2.1/mask_rcnn_balloon.h5",
help="Path to weights .h5 file or 'coco'")
parser.add_argument('--logs', required=False,
default=DEFAULT_LOGS_DIR,
metavar="/root/autodl-tmp/Mask_RCNN-2.1/logs",
help='Logs and checkpoints directory (default=logs/)')
parser.add_argument('--image', required=False,
metavar="/root/autodl-tmp/Mask_RCNN-2.1/balloon/balloon/val/14898532020_ba6199dd22_k.jpg",
help='Image to apply the color splash effect on')
parser.add_argument('--video', required=False,
metavar="path or URL to video",
help='Video to apply the color splash effect on')
args = parser.parse_args()
1.3 运行 ballon.py文件
在运行命令中添加必要的参数
python balloon.py train --dataset=/root/autodl-tmp/Mask_RCNN-2.1/balloon --weights=/root/autodl-tmp/Mask_RCNN-2.1/mask_rcnn_balloon.h5
dataset为ballon数据集文件夹地址, weights为mask_rcnn_balloon.h5文件地址
等待epoch结束
测试检测结果
2、运行时与到的问题
问题1:No module named ‘skimage’
默认安装的版本与其他库文件版本不匹配,重新根据环境安装0.19.3
pip install scikit-image==0.19.3
问题2:AttributeError: module ‘scipy.misc’ has no attribute ‘imresize’
代码中使用了scipy.misc模块的imresize函数,但是在最新的SciPy版本中,imresize已被弃用并从scipy.misc模块中移除。这导致了AttributeError: module ‘scipy.misc’ has no attribute 'imresize’错误
解决办法为找到代码中使用scipy.misc.imresize的地方。
将它替换为skimage.transform.resize。
例如
image = scipy.misc.imresize(
image, (round(h * scale), round(w * scale)))
替换为
from skimage.transform import resize
image = resize(image, (round(h * scale), round(w * scale)), mode='constant', anti_aliasing=True)
问题3:IndexError: index 1024 is out of bounds for axis 1 with size 1024
这个错误是由于尝试在 mask 中使用索引 1024,而 mask 的宽度是 1024,索引应该从 0 到 1023。要解决这个问题,可以确保索引不超出边界。
在load_mask函数中找到这句代码
mask[rr, cc, i] = 1
添加以下代码
# 添加检查,确保索引不超出边界
rr = np.clip(rr, 0, mask.shape[0] - 1)
cc = np.clip(cc, 0, mask.shape[1] - 1)
mask[rr, cc, i] = 1