mxnet symbol 打印模型所有中间输出

代码:

import mxnet as mx


def get_output_symbol(symbol):
  """
  Parameters
  ----------
  symbol: Symbol
      Symbol to be visualized.

  """
  import json
  from mxnet.symbol.symbol import Symbol
  if not isinstance(symbol, Symbol):
    raise TypeError("symbol must be Symbol")

  conf = json.loads(symbol.tojson())
  nodes = conf["nodes"]
  heads = set(conf["heads"][0])

  symbols = []
  for i, node in enumerate(nodes):
    op = node["op"]
    if op == "null" and i > 0:
      continue
    if op != "null" or i in heads:
      symbols.append(node['name'])
  return symbols

def debug_model(model):
  # prepare data 准备输入数据
  input_blob=mx.nd.zeros(shape=(1,3,112,112),ctx=mx.cpu()
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值