MXNet官方文档教程(4):使用预训练好的模型

预在完全掌握MXNet前,如果你的应用方向是常见的一些问题,最快的发方法就是直接使用其他人已经训练好的模型。这样不但可以节省时间,也可以达到比较好的效果。这一篇就来讲讲如何使用已经训练好的模型。源连接:Predict with pre-trained models


使用预训练好的模型进行预测

本教程实现一个在全ImageNet数据集上预训练好的预测样例。ImageNet数据集包括超过一千万张图片和一万个类别。有关更加详细的阐述,参见。

我们首先载入模型:

import os,urllib

import mxnet as mx

def download(url,prefix=''):

    filename = prefix+url.split("/")[-1]

    if not os.path.exists(filename):

        urllib.urlretrieve(url,filename)

 

path='http://data.mxnet.io/models/imagenet-11k/'

download(path+'resnet-152/resnet-152-symbol.json','full-')

download(path+'resnet-152/resnet-152-0000.params','full-')

download(path+'synset.txt','full-')

 

with open('full-synset.txt','r') as f:

    synsets = [l.rstrip()for lin f]

 

sym, arg_params,aux_params = mx.model.load_checkpoint('full-resnet-152',0)

在编号为0的GPU上建立该模型:

mod= mx.mod.Module(symbol=sym, context=mx.gpu())
mod.bind(for_training=False, data_shapes=[('data', (1,3,224,224))])
mod.set_params(arg_params, aux_params)

接下来我们定义通过URL获取图片的函数和预测函数:

%matplotlib inline
import matplotlib
matplotlib.rc("savefig", dpi=100)
import matplotlib.pyplotasplt
import cv2
import numpy as np
from collections import namedtuple
Batch= namedtuple('Batch', ['data'])
 
def get_image(url, show=True):
    filename = url.split("/")[-1]
    urllib.urlretrieve(url, filename)
    img = cv2.imread(filename)
    if img isNone:
        print('failed to download '+ url)
    if show:
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        plt.axis('off')
    return filename
 
def predict(filename, mod, synsets):
    img = cv2.cvtColor(cv2.imread(filename), cv2.COLOR_BGR2RGB)
    if img isNone:
        returnNone
    img = cv2.resize(img, (224,224))
    img = np.swapaxes(img,0,2)
    img = np.swapaxes(img,1,2)
    img = img[np.newaxis, :] 
    
    mod.forward(Batch([mx.nd.array(img)]))
    prob = mod.get_outputs()[0].asnumpy()
    prob = np.squeeze(prob)
 
    a = np.argsort(prob)[::-1]    
    for i in a[0:5]:
        print('probability=%f, class=%s'%(prob[i], synsets[i]))
   

我们就可以分类对图像进行分类并输出主要预测的类别。

url='http://writm.com/wp-content/uploads/2016/08/Cat-hd-wallpapers.jpg'
predict(get_image(url), mod, synsets)

输出:

probability=0.692314, class=n02122948 kitten, kitty
probability=0.043846, class=n01323155 kit
probability=0.030001, class=n01318894 pet
probability=0.029692, class=n02122878 tabby, queen
probability=0.026972, class=n01322221 baby

url='https://images-na.ssl-images-amazon.com/images/G/01/img15/pet-products/small-tiles/23695_pets_vertical_store_dogs_small_tile_8._CB312176604_.jpg'
predict(get_image(url), mod, synsets)

输出:

probability=0.569505, class=n02088364 beagle
probability=0.052795, class=n00452864 beagling
probability=0.039277, class=n02778669 ball
probability=0.017777, class=n02087122 hunting dog
probability=0.016321, class=n10611613 sleuth, sleuthhound

  • 4
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 3
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值