由于从 AWS 下载预训练模型很不稳定,可以先用下载工具把模型参数文件(.params)下载下来,拷贝到代码所在的文件夹下,再用下面的代码从本地文件加载:
# Build the network structure only: pretrained=False and pretrained_base=False
# are passed so that the weights come from the local .params file loaded
# below instead of being fetched by get_model.
net = model_zoo.get_model(
    'mask_rcnn_resnet50_v1b_coco', pretrained=False, pretrained_base=False)
ctx = mx.gpu(0)
# Load the manually downloaded checkpoint directly onto the GPU context.
# NOTE(review): allow_missing=True tolerates parameters that are absent from
# the file; confirm this is intended for a full pretrained checkpoint.
net.load_parameters(
    './mask_rcnn_resnet50_v1b_coco-a3527fdc.params', ctx=ctx,
    ignore_extra=False, allow_missing=True)
im_fname = utils.download(
    'https://github.com/dmlc/web-data/blob/master/'
    'gluoncv/detection/biking.jpg?raw=true',
    path='biking.jpg')
x, orig_img = data.transforms.presets.rcnn.load_test(im_fname)
# Important: copy the input to the GPU context used for the parameters,
# otherwise an error is raised when the network is run on it.
x = x.copyto(ctx)
如果想改变类别数目,可以像下面这样复用预训练网络的特征层,并重新初始化输出层:
def get_model(model, ctx, opt):
    """Build a network with pretrained features and a freshly initialized output.

    Two networks are created through ``models.get_model`` with the same
    ``'mask_'`` prefix: one with ``pretrained=True`` and one with
    ``pretrained=False`` and ``classes=classes``. The feature layers of the
    pretrained network are attached to the second network, whose output
    layer is then initialized with Xavier initialization on ``ctx``.

    Args:
        model: Model name passed to ``models.get_model``. Names starting
            with ``'resnet'`` additionally receive ``thumbnail``; names
            starting with ``'vgg'`` additionally receive ``batch_norm``.
        ctx: Context on which the networks are created and the output layer
            is initialized.
        opt: Options object. ``use_thumbnail`` or ``batch_norm`` is read
            depending on the model name; ``use_pretrained`` is set to
            ``True`` on it as a side effect.

    Returns:
        The network with pretrained features and a new output layer.

    Note:
        ``classes`` and ``models`` are module-level names that must be
        defined elsewhere in the file (``classes`` is presumably the target
        number of classes -- confirm against its definition).
    """
    # Pretrained features are always used here, so record that on the options.
    opt.use_pretrained = True

    kwargs = {'ctx': ctx, 'pretrained': False, 'prefix': 'mask_',
              'classes': classes}
    if model.startswith('resnet'):
        kwargs['thumbnail'] = opt.use_thumbnail
    elif model.startswith('vgg'):
        kwargs['batch_norm'] = opt.batch_norm

    # The pretrained network is only a donor for its feature layers.
    prekwargs = {'ctx': ctx, 'prefix': 'mask_', 'pretrained': True}
    prenet = models.get_model(model, **prekwargs)

    net = models.get_model(model, **kwargs)
    net.features = prenet.features
    # Only the new output layer still needs initialization.
    net.output.initialize(mx.init.Xavier(), ctx)
    return net