mxnet模型加载--源码分析--后端C++做了什么

1. 模型加载——sym.load()实现

1.1 加载symbol

1.1.1 Python 端 sym.load() ——> 通过 ctypes 调用 C API MXSymbolCreateFromFile

symbol = sym.load('%s-symbol.json' % prefix)  # prefix is the model path prefix; the file read is "<prefix>-symbol.json"
def load(fname):
    """Load a Symbol from a JSON file.

    The JSON format is language agnostic: a file written by ``Symbol.save``
    from any mxnet language binding can be read back here.  Pickle only works
    from Python, whereas this path may also point directly at cloud storage
    (S3, HDFS).

    Parameters
    ----------
    fname : str
        Location of the symbol file, for example:

        - `s3://my-bucket/path/my-s3-symbol`
        - `hdfs://my-bucket/path/my-hdfs-symbol`
        - `/path-to/my-local-symbol`

    Returns
    -------
    sym : Symbol
        The symbol read from ``fname``.

    Raises
    ------
    TypeError
        If ``fname`` is not a string.

    See Also
    --------
    Symbol.save : Used to save symbol into file.
    """
    if isinstance(fname, string_types):
        symbol_handle = SymbolHandle()
        # Delegate to the C API MXSymbolCreateFromFile through ctypes; it fills
        # symbol_handle with the newly created backend symbol.
        check_call(_LIB.MXSymbolCreateFromFile(c_str(fname),
                                               ctypes.byref(symbol_handle)))
        return Symbol(symbol_handle)
    raise TypeError('fname need to be string')

1.1.2 MXSymbolCreateFromFile:用 dmlc::Stream::Create 生成文件流,读取 .json 文件内容

int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
  nnvm::Symbol *s = new nnvm::Symbol();
  API_BEGIN();
  std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
  dmlc::istream is(fi.get());
  nnvm::Graph g;
  g.attrs["json"] = std::make_shared<nnvm::any>(
    std::string(std::istreambuf_iterator<char>(is), std::istreambuf_iterator<char>()));
  s->outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
  *out = s;
  is.set_stream(nullptr);
  API_END_HANDLE_ERROR(delete s);
}

1.1.3 将char*转换为URI结构体

/*!
 * \brief Open a stream for `uri`, dispatching on its protocol prefix.
 * \param uri path or URI string, e.g. "/local/file" or "hdfs://host/path".
 * \param flag fopen-style open mode forwarded to the filesystem.
 * \param try_create forwarded to FileSystem::Open as its allow-null flag.
 */
Stream *Stream::Create(const char *uri,
                       const char * const flag,
                       bool try_create) {
  // Split the raw string into protocol / host / name.
  const io::URI parsed(uri);
  // Pick the filesystem singleton for that protocol, then let it open the path.
  io::FileSystem *filesystem = io::FileSystem::GetInstance(parsed);
  return filesystem->Open(parsed, flag, try_create);
}

1.1.4 URI结构体的内容

需要注意的三个数据成员:

protocol:协议前缀(含 "://"),用于判断模型地址是本地路径(file:// 或无前缀)还是远程地址(hdfs://、s3://、http(s):// 等)

host: 主机名(HDFS 的 namenode、S3 的 bucket 名)

name: 路径部分;对本地文件而言就是模型文件的路径

/*!
 * \brief A path/URI split into protocol, host and name.
 *
 * "hdfs://namenode/a/b" -> protocol "hdfs://", host "namenode", name "/a/b".
 * A string without "://" is kept whole in `name`. Note that a Windows drive
 * path written as "W://dir/file" also contains "://" and therefore parses as
 * protocol "W://", host "dir", name "/file".
 */
struct URI {
  /*! \brief protocol prefix including the trailing "://"; empty for plain paths */
  std::string protocol;
  /*!
   * \brief host name, namenode for HDFS, bucket name for s3
   */
  std::string host;
  /*! \brief name of the path (everything from the first '/' after the host) */
  std::string name;
  /*! \brief enable default constructor */
  URI(void) {}
  /*!
   * \brief construct from URI string
   */
  explicit URI(const char *uri) {
    const std::string text(uri);
    const std::string::size_type scheme_end = text.find("://");
    if (scheme_end == std::string::npos) {
      // No scheme: the whole string is a path.
      name = text;
      return;
    }
    protocol = text.substr(0, scheme_end + 3);
    const std::string rest = text.substr(scheme_end + 3);
    const std::string::size_type slash = rest.find('/');
    if (slash == std::string::npos) {
      // Only a host was given, e.g. "s3://bucket": use the root path.
      host = rest;
      name = "/";
    } else {
      host = rest.substr(0, slash);
      name = rest.substr(slash);
    }
  }
  /*! \brief string representation */
  inline std::string str(void) const {
    std::string joined(protocol);
    joined += host;
    joined += name;
    return joined;
  }
};

1.1.5 GetInstance(path)——根据上一步解析出的 protocol 返回对应文件系统(FileSystem)的单例,再由其 Open() 创建文件流(stream)

/*!
 * \brief Return the FileSystem singleton that serves the protocol of `path`.
 *
 * An empty protocol or "file://" maps to the local filesystem; "hdfs://" and
 * "viewfs://" to HDFS; "s3://", "http://" and "https://" to S3; "azure://"
 * to Azure. Remote backends need the matching DMLC_USE_* build flag,
 * otherwise LOG(FATAL) is raised. Any other protocol (for example "W://",
 * which is what a Windows path written as "W://dir/..." parses to) is also
 * rejected with LOG(FATAL).
 */
FileSystem *FileSystem::GetInstance(const URI &path) {
  if (path.protocol == "file://" || path.protocol.length() == 0) {
    // Plain paths and file:// URIs are served by the local filesystem.
    return LocalFileSystem::GetInstance();
  }
  if (path.protocol == "hdfs://" || path.protocol == "viewfs://") {
#if DMLC_USE_HDFS
    if (path.host.length() == 0) {
      return HDFSFileSystem::GetInstance("default");
    } else if (path.protocol == "viewfs://") {
      // host is non-empty here (the empty case returned above), so
      // viewfs://<host> must be the configured fs.defaultFS.
      char* defaultFS = nullptr;
      hdfsConfGetStr("fs.defaultFS", &defaultFS);
      // hdfsConfGetStr leaves the value NULL when the key is missing; comparing
      // a std::string against NULL is undefined behaviour, so test it first.
      const bool matches_default =
          defaultFS != nullptr && ("viewfs://" + path.host == defaultFS);
      // The string returned by hdfsConfGetStr must be released by the caller.
      if (defaultFS != nullptr) hdfsConfStrFree(defaultFS);
      CHECK(matches_default)
          << "viewfs is only supported as a fs.defaultFS.";
      return HDFSFileSystem::GetInstance("default");
    } else {
      return HDFSFileSystem::GetInstance(path.host);
    }
#else
    LOG(FATAL) << "Please compile with DMLC_USE_HDFS=1 to use hdfs";
#endif
  }
  if (path.protocol == "s3://" || path.protocol == "http://" || path.protocol == "https://") {
#if DMLC_USE_S3
    return S3FileSystem::GetInstance();
#else
    LOG(FATAL) << "Please compile with DMLC_USE_S3=1 to use S3";
#endif
  }

  if (path.protocol == "azure://") {
#if DMLC_USE_AZURE
    return AzureFileSystem::GetInstance();
#else
    LOG(FATAL) << "Please compile with DMLC_USE_AZURE=1 to use Azure";
#endif
  }

  LOG(FATAL) << "unknown filesystem protocol " + path.protocol;
  return NULL;
}

1.1.6 LocalFileSystem:获取本地文件系统单例,并打开本地文件流

/*!
 * \brief Access the process-wide LocalFileSystem object.
 * \return pointer to a function-local static; it is constructed on the first
 *         call (initialisation is thread-safe since C++11) and never freed by callers.
 */
inline static LocalFileSystem *GetInstance(void) {
  static LocalFileSystem singleton;
  return &singleton;
}

// Open a file stream (stdin/stdout or a regular local file).
/*!
 * \brief Open `path.name` on the local filesystem and wrap it in a FileStream.
 * \param path URI whose `name` is the local path; "stdin"/"stdout" select the
 *        standard streams unless DMLC_DISABLE_STDIN is defined. A leading
 *        "file://" in the name is stripped.
 * \param mode fopen-style mode; plain "r"/"w" are promoted to "rb"/"wb".
 * \param allow_null when true a failed open returns NULL; otherwise the
 *        CHECK fails with the strerror(errno) message.
 * \return a newly allocated FileStream, or NULL (only when allow_null).
 */
SeekStream *LocalFileSystem::Open(const URI &path,
                                  const char* const mode,
                                  bool allow_null) {
  bool use_stdio = false;
  FILE *fp = NULL;
#ifdef _WIN32
  // Convert the UTF-8 path and mode to UTF-16 so non-ASCII paths can be opened
  // with the wide-character CRT. Passing -1 as the source length makes the
  // returned length include the terminating NUL.
  const int fname_length = MultiByteToWideChar(CP_UTF8, 0, path.name.c_str(), -1, nullptr, 0);
  CHECK(fname_length > 0) << " LocalFileSystem::Open \"" << path.str()
                          << "\": " << "Invalid character sequence.";
  std::wstring fname(fname_length, 0);
  MultiByteToWideChar(CP_UTF8, 0, path.name.c_str(), -1, &fname[0], fname_length);

  const int mode_length = MultiByteToWideChar(CP_UTF8, 0, mode, -1, nullptr, 0);
  std::wstring wmode(mode_length, 0);
  MultiByteToWideChar(CP_UTF8, 0, mode, -1, &wmode[0], mode_length);

  using namespace std;
#ifndef DMLC_DISABLE_STDIN
  if (!wcscmp(fname.c_str(), L"stdin")) {
    use_stdio = true; fp = stdin;
  }
  if (!wcscmp(fname.c_str(), L"stdout")) {
    use_stdio = true; fp = stdout;
  }
#endif  // DMLC_DISABLE_STDIN
  if (!wcsncmp(fname.c_str(), L"file://", 7)) { fname = fname.substr(7); }
  if (!use_stdio) {
    // Rebuild from c_str() to drop the embedded terminating NUL.
    std::wstring flag(wmode.c_str());
    if (flag == L"w") flag = L"wb";
    if (flag == L"r") flag = L"rb";
    // Path and mode are wide strings here, so the wide-character _wfopen is
    // required whatever DMLC_USE_FOPEN64 says; narrow fopen() cannot accept
    // std::wstring / wchar_t* arguments (the old #else branch did not compile).
    fp = _wfopen(fname.c_str(), flag.c_str());
  }
#else  // _WIN32
  const char *fname = path.name.c_str();
  using namespace std;
#ifndef DMLC_DISABLE_STDIN
  if (!strcmp(fname, "stdin")) {
    use_stdio = true; fp = stdin;
  }
  if (!strcmp(fname, "stdout")) {
    use_stdio = true; fp = stdout;
  }
#endif  // DMLC_DISABLE_STDIN
  if (!strncmp(fname, "file://", 7)) fname += 7;
  if (!use_stdio) {
    // Always open in binary mode so no newline translation happens.
    std::string flag = mode;
    if (flag == "w") flag = "wb";
    if (flag == "r") flag = "rb";
#if DMLC_USE_FOPEN64
    fp = fopen64(fname, flag.c_str());
#else  // DMLC_USE_FOPEN64
    fp = fopen(fname, flag.c_str());
#endif  // DMLC_USE_FOPEN64
  }
#endif  // _WIN32
  if (fp != NULL) {
    return new FileStream(fp, use_stdio);
  } else {
    CHECK(allow_null) << " LocalFileSystem::Open \"" << path.str() << "\": " << strerror(errno);
    return NULL;
  }
}
  

未完待续

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值