mxnet测试网络速度

本文档详细介绍了如何使用MXNet进行模型计算FLOPs(浮点运算次数),包括数据预处理、模型加载和执行速度测量。通过解析symbol.json文件,作者展示了如何设置数据和标签形状,并演示了如何在一张图片上多次执行模型以测量性能。
摘要由CSDN通过智能技术生成

 

# -*- coding: utf-8 -*-
"""
File Name: calculate_flops.py
Author: liangdepeng
mail: liangdepeng@gmail.com
"""
import time

import cv2
import mxnet as mx
import argparse
import numpy as np
import json
import re


def parse_args():
    parser = argparse.ArgumentParser(description='')
    parser.add_argument('-ds', '--data_shapes', default=["data,1,3,112,112"], type=str, nargs='+',
                        help='data_shapes, format: arg_name,s1,s2,...,sn, example: data,1,3,224,224')
    parser.add_argument('-ls', '--label_shapes', default=["label,1,512"], type=str, nargs='+',
                        help='label_shapes, format: arg_name,s1,s2,...,sn, example: label,1,1,224,224')
    parser.add_argument('-s', '--symbol_path', type=str, default=r'softmax_label-symbol.json', help='')
    #
    return parser.parse_args()

def single_input(path):
    img = cv2.imread(path)
    # mxnet三通道输入是严格的RGB格式,而cv2.imread的默认是BGR格式,因此需要做一个转换
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (112, 112))
    img = img.transpose(2, 0, 1)

    # 添加一个第四维度并构建NDArray
    img = img[np.newaxis, :]
    array = mx.nd.array(img)
    return array


if __name__ == '__main__':

    from collections import namedtuple
    args = parse_args()
    sym = mx.sym.load(args.symbol_path)

    data_shapes = list()
    data_names = list()
    if args.data_shapes is not None and len(args.data_shapes) > 0:
        for shape in args.data_shapes:
            items = shape.replace('\'', '').replace('"', '').split(',')
            data_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
            data_names.append(items[0])

    label_shapes = None
    label_names = list()
    if args.label_shapes is not None and len(args.label_shapes) > 0:
        label_shapes = list()
        for shape in args.label_shapes:
            items = shape.replace('\'', '').replace('"', '').split(',')
            label_shapes.append((items[0], tuple([int(s) for s in items[1:]])))
            label_names.append(items[0])

    devs = [mx.cpu()]

    if len(label_names) == 0:
        label_names = None
    model = mx.mod.Module(context=devs, symbol=sym, data_names=data_names, label_names=None)
    model.bind(data_shapes=data_shapes,  for_training=False)

    model._params_dirty = True
    model.params_initialized = True


    model.save_checkpoint("0412", 1)

    time1 = time.time()


    # print("模型加载和重建时间:{0}".format(time1 - time0))

    Batch = namedtuple("batch", ['data'])

    img1_path = r'1054086.jpg'

    array1 = single_input(img1_path)
    start1 = time.time()
    for i in range(40):
        start = time.time()
        model.forward(Batch([array1]))
        vector1 = model.get_outputs()[0].asnumpy()
        print("time", time.time() - start, vector1.shape)
    print("total time", time.time() - start1)


 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

AI算法网奇

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值