文章目录
1. 模型加载——sym.load()实现
1.1 加载symbol
1.1.1 Python 端 sym.load() —— 通过 ctypes 调用 C API MXSymbolCreateFromFile
symbol = sym.load('%s-symbol.json' % prefix)  # %s (filled in from prefix) is the absolute path where the model is stored
def load(fname):
    """Load a Symbol from a JSON file written by ``Symbol.save``.

    Prefer this over pickle when the file must be portable: the JSON format
    is language agnostic, so a symbol saved from any mxnet language binding
    can be loaded here. Besides local paths, remote locations such as S3 or
    HDFS can be used when the native library was built with that backend.

    Parameters
    ----------
    fname : str
        Location of the symbol file, for example:

        - ``s3://my-bucket/path/my-s3-symbol``
        - ``hdfs://my-bucket/path/my-hdfs-symbol``
        - ``/path-to/my-local-symbol``

    Returns
    -------
    sym : Symbol
        The loaded symbol.

    Raises
    ------
    TypeError
        If ``fname`` is not a string.

    See Also
    --------
    Symbol.save : Used to save symbol into file.
    """
    if isinstance(fname, string_types):
        # The C API fills in an opaque handle; wrap it in a Python Symbol.
        out_handle = SymbolHandle()
        check_call(_LIB.MXSymbolCreateFromFile(c_str(fname),
                                               ctypes.byref(out_handle)))
        return Symbol(out_handle)
    raise TypeError('fname need to be string')
1.1.2 C API MXSymbolCreateFromFile —— 用 dmlc::Stream::Create 生成文件流，读取 .json 文件
// C API entry point (called from Python through ctypes): reads the symbol
// JSON at `fname` and hands a newly allocated nnvm::Symbol back through `out`.
// The return value is produced by the API_BEGIN / API_END_HANDLE_ERROR macros.
// NOTE(review): presumably 0 on success and non-zero on error; confirm
// against the macro definitions.
int MXSymbolCreateFromFile(const char *fname, SymbolHandle *out) {
// Declared before API_BEGIN() so that `s` is still in scope for the cleanup
// expression `delete s` passed to API_END_HANDLE_ERROR below.
nnvm::Symbol *s = new nnvm::Symbol();
API_BEGIN();
// Open `fname` in read mode through dmlc's filesystem abstraction; the
// unique_ptr destroys the stream when this scope ends.
std::unique_ptr<dmlc::Stream> fi(dmlc::Stream::Create(fname, "r"));
// Adapt the dmlc::Stream to a std::istream so it can be drained with iterators.
dmlc::istream is(fi.get());
nnvm::Graph g;
// Read the whole file into one string and attach it as the graph's "json" attribute.
g.attrs["json"] = std::make_shared<nnvm::any>(
std::string(std::istreambuf_iterator<char>(is), std::istreambuf_iterator<char>()));
// Run the "LoadLegacyJSON" pass on that graph and keep its outputs as the
// symbol's outputs. NOTE(review): presumably this pass parses the JSON and
// upgrades older serialization formats; confirm.
s->outputs = nnvm::ApplyPass(g, "LoadLegacyJSON").outputs;
*out = s;
// Detach the underlying stream from `is` on the success path.
// NOTE(review): presumably so `is` no longer references the stream owned by `fi`.
is.set_stream(nullptr);
// On error, free the partially built symbol before returning.
API_END_HANDLE_ERROR(delete s);
}
1.1.3 dmlc::Stream::Create —— 将 char* 路径解析为 URI 结构体，并交给对应的文件系统打开
/*!
 * \brief Create a stream for `uri`, dispatching on the protocol it names.
 * \param uri location to open: a plain local path or "protocol://host/path"
 * \param flag open mode, forwarded unchanged to the filesystem's Open()
 * \param try_create forwarded as Open()'s third argument
 *        (`allow_null` in LocalFileSystem::Open)
 * \return the stream produced by the selected filesystem's Open()
 */
Stream *Stream::Create(const char *uri,
                       const char * const flag,
                       bool try_create) {
  // Split the raw string into protocol / host / name.
  io::URI parsed(uri);
  // Choose the backend for that protocol, then let it open the file.
  io::FileSystem *backend = io::FileSystem::GetInstance(parsed);
  return backend->Open(parsed, flag, try_create);
}
1.1.4 URI结构体的内容
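需要注意的三个数据成员：
protocol：协议（包含结尾的 "://"），用于判断模型地址是本地文件（file:// 或不带协议），还是 hdfs://、s3://、http(s):// 等远程地址；
host：主机名，例如 HDFS 的 namenode 或 S3 的 bucket 名；
name：路径部分（输入不含 "://" 时就是整个字符串）。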
/*!
 * \brief A location split into protocol, host and path, e.g.
 *   "hdfs://namenode/data/f.json" -> protocol "hdfs://", host "namenode",
 *   name "/data/f.json". A string without "://" is kept whole in `name`.
 */
struct URI {
  /*! \brief protocol including the trailing "://" (e.g. "hdfs://"); empty for plain paths */
  std::string protocol;
  /*!
   * \brief host name, namenode for HDFS, bucket name for s3
   */
  std::string host;
  /*! \brief name of the path */
  std::string name;
  /*! \brief enable default constructor */
  URI(void) {}
  /*!
   * \brief construct from URI string
   *
   *   "proto://host/rest" -> protocol "proto://", host "host", name "/rest"
   *   "proto://host"      -> protocol "proto://", host "host", name "/"
   *   no "://" present    -> the whole string goes to `name`
   *
   * A one-letter prefix before "://" (e.g. "W://Clinical/EN/Unet-symbol.json")
   * is a Windows drive letter, not a protocol. FileSystem::GetInstance
   * recognizes no one-character protocol, so treating it as one could only
   * fail with "unknown filesystem protocol". Such strings are kept whole in
   * `name` and are therefore handled as local paths.
   * \param uri NUL-terminated string; NULL yields an empty URI.
   */
  explicit URI(const char *uri) {
    if (uri == NULL) return;  // nothing to parse: leave every field empty
    const char *p = std::strstr(uri, "://");
    if (p == NULL || IsDriveLetterPrefix(uri, p)) {
      name = uri;
      return;
    }
    protocol = std::string(uri, p - uri + 3);  // keep "://" as part of the protocol
    uri = p + 3;
    p = std::strchr(uri, '/');  // the first '/' after the host starts the path
    if (p == NULL) {
      host = uri;
      name = "/";
    } else {
      host = std::string(uri, p - uri);
      name = p;
    }
  }
  /*! \brief string representation */
  inline std::string str(void) const {
    return protocol + host + name;
  }

 private:
  /*! \brief true when the text in [begin, sep) is exactly one ASCII letter (a drive letter) */
  static bool IsDriveLetterPrefix(const char *begin, const char *sep) {
    if (sep - begin != 1) return false;
    const char c = begin[0];
    return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z');
  }
};
1.1.5 FileSystem::GetInstance(path) —— 根据上一步解析出的 protocol 选择对应的文件系统（FileSystem）单例，再由它的 Open() 创建文件流（stream）
/*!
 * \brief Return the FileSystem singleton that handles `path.protocol`.
 *
 *   "" or "file://"                -> LocalFileSystem
 *   "hdfs://" or "viewfs://"       -> HDFSFileSystem  (needs DMLC_USE_HDFS)
 *   "s3://", "http://", "https://" -> S3FileSystem    (needs DMLC_USE_S3)
 *   "azure://"                     -> AzureFileSystem (needs DMLC_USE_AZURE)
 *
 * Unknown protocols, and protocols whose backend was not compiled in, are
 * reported with LOG(FATAL).
 * \param path parsed location; only `protocol` and `host` are inspected here.
 * \return the filesystem instance (the final NULL is only reached if LOG(FATAL) returns).
 */
FileSystem *FileSystem::GetInstance(const URI &path) {
  if (path.protocol == "file://" || path.protocol.length() == 0) {
    // Plain paths (no "://") and file:// URIs go to the local filesystem.
    return LocalFileSystem::GetInstance();
  }
  if (path.protocol == "hdfs://" || path.protocol == "viewfs://") {
#if DMLC_USE_HDFS
    if (path.host.length() == 0) {
      // No host given: fall back to the instance named "default".
      return HDFSFileSystem::GetInstance("default");
    } else if (path.protocol == "viewfs://") {
      // A viewfs URI is only accepted when it names fs.defaultFS.
      // hdfsConfGetStr leaves the out-param NULL when the key is not set, and a
      // returned string must be released with hdfsConfStrFree. Copy it into a
      // std::string first so the comparison below never sees NULL and the
      // buffer is not leaked.
      char* defaultFS = nullptr;
      hdfsConfGetStr("fs.defaultFS", &defaultFS);
      const std::string default_fs =
          defaultFS != nullptr ? std::string(defaultFS) : std::string();
      if (defaultFS != nullptr) hdfsConfStrFree(defaultFS);
      // path.host is non-empty here: the empty-host case returned above.
      CHECK("viewfs://" + path.host == default_fs)
          << "viewfs is only supported as a fs.defaultFS.";
      return HDFSFileSystem::GetInstance("default");
    } else {
      return HDFSFileSystem::GetInstance(path.host);
    }
#else
    LOG(FATAL) << "Please compile with DMLC_USE_HDFS=1 to use hdfs";
#endif
  }
  if (path.protocol == "s3://" || path.protocol == "http://" || path.protocol == "https://") {
#if DMLC_USE_S3
    return S3FileSystem::GetInstance();
#else
    LOG(FATAL) << "Please compile with DMLC_USE_S3=1 to use S3";
#endif
  }
  if (path.protocol == "azure://") {
#if DMLC_USE_AZURE
    return AzureFileSystem::GetInstance();
#else
    LOG(FATAL) << "Please compile with DMLC_USE_AZURE=1 to use Azure";
#endif
  }
  LOG(FATAL) << "unknown filesystem protocol " + path.protocol;
  return NULL;
}
1.1.6 LocalFileSystem
/*! \brief Return the process-wide LocalFileSystem, constructed on first use. */
inline static LocalFileSystem *GetInstance(void) {
  // Function-local static: built once, on the first call
  // (its initialization is thread-safe since C++11).
  static LocalFileSystem local_fs;
  LocalFileSystem *const self = &local_fs;
  return self;
}
/*!
 * \brief Open a local file, or the process's stdin/stdout, as a stream.
 * \param path parsed location. `path.name` is the file name to open; unless
 *        DMLC_DISABLE_STDIN is defined, "stdin"/"stdout" select the standard
 *        streams. `path.str()` is used in error messages.
 * \param mode fopen-style mode; exactly "r" or "w" is widened to "rb"/"wb"
 *        so the file is accessed in binary mode.
 * \param allow_null if true, a failed open returns NULL; otherwise the CHECK
 *        fails with the strerror(errno) text.
 * \return a newly allocated FileStream wrapping the FILE*, or NULL when the
 *         open failed and allow_null is true.
 */
SeekStream *LocalFileSystem::Open(const URI &path,
                                  const char* const mode,
                                  bool allow_null) {
  bool use_stdio = false;
  FILE *fp = NULL;
#ifdef _WIN32
  // Windows: convert the UTF-8 name and mode to UTF-16 so that non-ASCII
  // file names can be opened through the wide-character CRT API. Passing -1
  // as the source length makes the returned sizes include the terminating NUL.
  const int fname_length = MultiByteToWideChar(CP_UTF8, 0, path.name.c_str(), -1, nullptr, 0);
  CHECK(fname_length > 0) << " LocalFileSystem::Open \"" << path.str()
                          << "\": " << "Invalid character sequence.";
  std::wstring fname(fname_length, 0);
  MultiByteToWideChar(CP_UTF8, 0, path.name.c_str(), -1, &fname[0], fname_length);
  const int mode_length = MultiByteToWideChar(CP_UTF8, 0, mode, -1, nullptr, 0);
  // Validate the mode conversion the same way as the file-name conversion.
  CHECK(mode_length > 0) << " LocalFileSystem::Open \"" << path.str()
                         << "\": " << "Invalid character sequence in mode.";
  std::wstring wmode(mode_length, 0);
  MultiByteToWideChar(CP_UTF8, 0, mode, -1, &wmode[0], mode_length);
  using namespace std;
#ifndef DMLC_DISABLE_STDIN
  if (!wcscmp(fname.c_str(), L"stdin")) {
    use_stdio = true; fp = stdin;
  }
  if (!wcscmp(fname.c_str(), L"stdout")) {
    use_stdio = true; fp = stdout;
  }
#endif  // DMLC_DISABLE_STDIN
  // Strip an explicit file:// scheme if it is still present in the name.
  if (!wcsncmp(fname.c_str(), L"file://", 7)) { fname = fname.substr(7); }
  if (!use_stdio) {
    // Building from c_str() drops the trailing NUL that wmode carries.
    std::wstring flag(wmode.c_str());
    if (flag == L"w") flag = L"wb";
    if (flag == L"r") flag = L"rb";
    // `fname` and `flag` are wide strings, so the wide-character _wfopen is
    // the only valid call here whatever DMLC_USE_FOPEN64 is set to; narrow
    // fopen() cannot accept wchar_t arguments and would not compile.
    fp = _wfopen(fname.c_str(), flag.c_str());
  }
#else  // _WIN32
  const char *fname = path.name.c_str();
  using namespace std;
#ifndef DMLC_DISABLE_STDIN
  if (!strcmp(fname, "stdin")) {
    use_stdio = true; fp = stdin;
  }
  if (!strcmp(fname, "stdout")) {
    use_stdio = true; fp = stdout;
  }
#endif  // DMLC_DISABLE_STDIN
  // Strip an explicit file:// scheme if it is still present in the name.
  if (!strncmp(fname, "file://", 7)) fname += 7;
  if (!use_stdio) {
    std::string flag = mode;
    if (flag == "w") flag = "wb";
    if (flag == "r") flag = "rb";
#if DMLC_USE_FOPEN64
    fp = fopen64(fname, flag.c_str());  // large-file (LFS) variant of fopen
#else  // DMLC_USE_FOPEN64
    fp = fopen(fname, flag.c_str());
#endif  // DMLC_USE_FOPEN64
  }
#endif  // _WIN32
  if (fp != NULL) {
    return new FileStream(fp, use_stdio);
  } else {
    CHECK(allow_null) << " LocalFileSystem::Open \"" << path.str() << "\": " << strerror(errno);
    return NULL;
  }
}