代码:
import mxnet as mx
def get_output_symbol(symbol):
    """Return the names of the graph nodes worth inspecting in ``symbol``.

    The symbol is serialized to its JSON graph description and its nodes are
    walked in serialized order. The result contains:

    * the name of node 0, the first node of the serialized graph. For a
      typical network this is presumably the input variable (e.g. ``data``);
      TODO confirm for the model in use.
    * the name of every operator node, i.e. every node whose ``op`` is not
      ``"null"``.

    All other ``"null"`` nodes are skipped. These are variables such as
    weights, biases and auxiliary states.

    Parameters
    ----------
    symbol : mxnet.symbol.Symbol
        Symbol whose graph nodes are listed.

    Returns
    -------
    list of str
        Node names, in the order they appear in the serialized graph.

    Raises
    ------
    TypeError
        If ``symbol`` is not an ``mxnet.symbol.Symbol``.
    """
    import json
    from mxnet.symbol.symbol import Symbol

    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be Symbol")
    nodes = json.loads(symbol.tojson())["nodes"]
    # Node 0 is kept explicitly rather than by testing membership in the
    # graph's "heads" list. Each head entry is a [node_id, output_index,
    # version] triple, so building a set from one entry mixes node ids with
    # index/version numbers and only matches node 0 by coincidence.
    return [
        node["name"]
        for i, node in enumerate(nodes)
        if i == 0 or node["op"] != "null"
    ]
def debug_model(model):
# prepare data 准备输入数据
input_blob=mx.nd.zeros(shape=(1,3,112,112),ctx=mx.cpu()