import os
import urllib.request
import shutil
# Download location of the COCO-trained Mask R-CNN weights. The URL points at
# the v2.0 release asset of the matterport/Mask_RCNN repository on GitHub.
COCO_MODEL_URL = (
    "https://github.com/matterport/Mask_RCNN/releases/download/v2.0/"
    "mask_rcnn_coco.h5"
)
def download_trained_weights(coco_model_path, verbose=1):
    """Download COCO trained weights from Releases.

    The response body is streamed into a sibling ``<coco_model_path>.part``
    file and moved onto ``coco_model_path`` only once the transfer has
    finished. An interrupted or failed download therefore never leaves a
    truncated file at the final path, which an existence check would
    otherwise mistake for a complete download.

    coco_model_path: local path of COCO trained weights
    verbose: progress messages are printed when this is greater than 0

    Any exception raised while opening the URL, writing the file or moving it
    into place is re-raised after the partial file has been removed.
    """
    if verbose > 0:
        print("Downloading pretrained model to " + coco_model_path + " ...")

    # Build the temporary name next to the target so os.replace() stays on
    # the same filesystem. fspath() keeps path-like and bytes paths working.
    target_path = os.fspath(coco_model_path)
    part_suffix = b".part" if isinstance(target_path, bytes) else ".part"
    partial_path = target_path + part_suffix

    try:
        with urllib.request.urlopen(COCO_MODEL_URL) as resp:
            with open(partial_path, "wb") as out:
                shutil.copyfileobj(resp, out)
        # Publish the file only after every byte has been written.
        os.replace(partial_path, target_path)
    except BaseException:
        # BaseException so that Ctrl-C during a long download is covered too.
        # Cleanup is best effort: the partial file may not exist yet.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        raise

    if verbose > 0:
        print("... done downloading pretrained model!")
# The weights live in the current working directory (resolved to an absolute
# path when this code runs) and are fetched only if nothing exists at that
# path yet.
ROOT_DIR = os.path.abspath("./")
COCO_MODEL_PATH = os.path.join(ROOT_DIR, "mask_rcnn_coco.h5")

if not os.path.exists(COCO_MODEL_PATH):
    download_trained_weights(coco_model_path=COCO_MODEL_PATH)
# Use the code above to download the COCO-pretrained Mask R-CNN weights that matterport open-sourced on GitHub.