Session创建
创建Session时，需要先根据SessionOptions获取对应的session factory，再由该session factory创建session。
// Creates a new Session configured by `options` and stores it in
// `*out_session`.
//
// First asks SessionFactory::GetFactory() for the factory that accepts
// `options`. Then it sets the global `session_created` cell to true
// (defined elsewhere; presumably a monitoring gauge — confirm) and
// delegates creation to that factory.
//
// On failure of either step, `*out_session` is set to nullptr, the error is
// logged, and the failing Status is returned.
Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* chosen_factory = nullptr;
  Status lookup_status = SessionFactory::GetFactory(options, &chosen_factory);
  if (!lookup_status.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to get session factory: " << lookup_status;
    return lookup_status;
  }

  session_created->GetCell()->Set(true);

  Status create_status = chosen_factory->NewSession(options, out_session);
  if (!create_status.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to create session: " << create_status;
  }
  return create_status;
}
SessionFactory
SessionFactory定义在tensorflow/core/common_runtime/session_factory.h中。与一般的Factory实现类似，
它用一个全局的map保存所有已注册的factory（key是runtime_type），并提供静态方法Register来注册factory，以及静态方法GetFactory来根据options查找唯一匹配的factory。
Factory对外主要有两个接口：NewSession和Reset；另有AcceptsOptions，用于判断该factory是否接受给定的options。
// Returns the mutex that guards the session factory registry returned by
// session_factories(). The mutex is a function-local static, so every caller
// shares the same instance.
static mutex* get_session_factory_lock() {
  static mutex registry_mutex(LINKER_INITIALIZED);
  return &registry_mutex;
}
typedef std::unordered_map<string, SessionFactory*> SessionFactories;
SessionFactories* session_factories() {
static SessionFactories* factories = new SessionFactories; //map保存session
return factories;
}
} // namespace
// Registers `factory` under the name `runtime_type` in the process-wide
// registry. The registry is accessed while holding
// get_session_factory_lock(), so concurrent registrations are serialized.
//
// If a factory is already registered under `runtime_type`, the existing entry
// wins: `factory` is not stored and an error is logged.
// NOTE(review): in that case `factory` is neither stored nor deleted here;
// confirm callers do not leak it.
void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  // unordered_map::insert leaves the map unchanged and returns
  // {existing, false} when the key is already present.
  if (!session_factories()->insert({runtime_type, factory}).second) {
    LOG(ERROR) << "Two session factories are being registered "
               << "under " << runtime_type;
  }
}
// Finds the registered SessionFactory whose AcceptsOptions() returns true for
// `options` and stores it in `*out_factory`.
//
// Returns:
//   - OK when exactly one factory accepts `options`.
//   - Internal when more than one factory accepts `options`. The choice would
//     be ambiguous, so the candidates' runtime types are listed in the
//     message.
//   - NotFound when no registered factory accepts `options`.
// `*out_factory` is written only on success.
Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
  mutex_lock l(*get_session_factory_lock());  // could use reader lock
  std::vector<std::pair<string, SessionFactory*>> candidate_factories;
  for (const auto& session_factory : *session_factories()) {
    if (session_factory.second->AcceptsOptions(options)) {
      candidate_factories.push_back(session_factory);
    } else {
      VLOG(2) << "SessionFactory type " << session_factory.first
              << " does not accept target: " << options.target;
    }
  }

  if (candidate_factories.size() == 1) {
    *out_factory = candidate_factories[0].second;  // unique match
    return Status::OK();
  }

  if (candidate_factories.size() > 1) {
    // The registry assumes factories accept disjoint sets of options.
    // Overlap means there is no principled choice, so report every candidate.
    string candidate_names;
    for (const auto& candidate : candidate_factories) {
      if (!candidate_names.empty()) candidate_names += ", ";
      candidate_names += candidate.first;
    }
    return errors::Internal(
        "Multiple session factories registered for the given session "
        "options (target: \"",
        options.target, "\"). Candidate factories are {", candidate_names,
        "}.");
  }

  return errors::NotFound(
      "No session factory registered for the given session options "
      "(target: \"",
      options.target, "\").");
}
SessionFactory接口
这里定义的是Factory的接口，也就是说Factory可以有多种实现。这个类负责管理Factory（注册与查找），而不是管理Session。TensorFlow的Session也支持定制，也就是说我们可以实现自己的Session（以及对应的SessionFactory）。
class Session;
struct SessionOptions;

// Abstract factory for Session implementations, together with a process-wide
// registry of factories keyed by runtime type.
//
// Subclasses must implement NewSession() and AcceptsOptions(). Reset() has a
// default implementation. The static Register()/GetFactory() pair maintains
// and queries the registry.
class SessionFactory {
 public:
  // The three virtual functions below form the SessionFactory interface.

  // Pure virtual. Presumably creates a Session configured by `options` and
  // stores it in `*out_session`; exact contract is defined by each
  // implementation — confirm.
  virtual Status NewSession(const SessionOptions& options,
                            Session** out_session) = 0;

  // Pure virtual. Returns whether this factory accepts `options`.
  // GetFactory() selects the factory for which this returns true.
  virtual bool AcceptsOptions(const SessionOptions& options) = 0;

  // Default implementation ignores its arguments and returns an Unimplemented
  // error. Presumably overriding factories reset the resources named by
  // `containers` — confirm against an implementation.
  virtual Status Reset(const SessionOptions& options,
                       const std::vector<string>& containers) {
    return errors::Unimplemented("Reset()");
  }

  // Virtual so derived factories can be destroyed through a base pointer.
  virtual ~SessionFactory() {}

  // Static helpers for registering and looking up factories.

  // Adds `factory` to the global registry under `runtime_type`.
  static void Register(const string& runtime_type, SessionFactory* factory);

  // Looks up the registered factory that accepts `options` and stores it in
  // `*out_factory`.
  static Status GetFactory(const SessionOptions& options,
                           SessionFactory** out_factory);
};
默认实现的Factory
- DIRECT_SESSION：tensorflow/core/common_runtime/direct_session.cc，即单机模式下在本地运行的Session。
- GRPC_SESSION：tensorflow/core/distributed_runtime/rpc/grpc_session.cc，即分布式模式下基于gRPC的Session。