tensorflow 之session factory

Session创建

创建Session的需要根据option获取session factory, 再由session factory创建session。


Status NewSession(const SessionOptions& options, Session** out_session) {
  SessionFactory* factory;
  Status s = SessionFactory::GetFactory(options, &factory);//获取factory
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to get session factory: " << s;
    return s;
  }

  session_created->GetCell()->Set(true);
  s = factory->NewSession(options, out_session);  //创建session
  if (!s.ok()) {
    *out_session = nullptr;
    LOG(ERROR) << "Failed to create session: " << s;
  }
  return s;
}

SessionFactory

SessionFactory在tensorflow/core/common_runtime/session_factory.h。与一般Factory实现类似,

用了一个全局变量map保存了factory, key是runtime_type, 同时提供了一个静态方法来注册factory.

Factory有两个对外接口:NewSession和Reset

static mutex* get_session_factory_lock() {
  static mutex session_factory_lock(LINKER_INITIALIZED);
  return &session_factory_lock;
}

typedef std::unordered_map<string, SessionFactory*> SessionFactories;

SessionFactories* session_factories() {
  static SessionFactories* factories = new SessionFactories;  //map保存session
  return factories;
}

}  // namespace

void SessionFactory::Register(const string& runtime_type,
                              SessionFactory* factory) {
  mutex_lock l(*get_session_factory_lock());
  if (!session_factories()->insert({runtime_type, factory}).second) {  //注册session
    LOG(ERROR) << "Two session factories are being registered "
               << "under" << runtime_type;
  }
}



Status SessionFactory::GetFactory(const SessionOptions& options,
                                  SessionFactory** out_factory) {
  mutex_lock l(*get_session_factory_lock());  // could use reader lock

  std::vector<std::pair<string, SessionFactory*>> candidate_factories;
  for (const auto& session_factory : *session_factories()) {
    if (session_factory.second->AcceptsOptions(options)) {
      candidate_factories.push_back(session_factory);
    } else {
      VLOG(2) << "SessionFactory type " << session_factory.first
              << " does not accept target: " << options.target;
    }
  }

  if (candidate_factories.size() == 1) {
    *out_factory = candidate_factories[0].second; //找到且唯一
    return Status::OK();
  } else if (candidate_factories.size() > 1) {
    //报错
    
  } else {
    //报错
   
  }
}

SessionFactory接口

这里Factory的接口,也就是说Factory有多种实现。这个类是管理Factory的。不是管理Session的。tensorflow session也是支持定制的,也就是说我们可以实现自己的session。


class Session;
struct SessionOptions;

class SessionFactory {
 public:
  //三个虚函数,是SessionFactory的接口
  virtual Status NewSession(const SessionOptions& options,
                            Session** out_session) = 0;

  virtual bool AcceptsOptions(const SessionOptions& options) = 0;

  virtual Status Reset(const SessionOptions& options,
                       const std::vector<string>& containers) {
    return errors::Unimplemented("Reset()");
  }

  virtual ~SessionFactory() {}
  //静态方法,创建注册和获取factory
  static void Register(const string& runtime_type, SessionFactory* factory);
  static Status GetFactory(const SessionOptions& options,
                           SessionFactory** out_factory);
};

默认实现的Factory

  1. DIRECT_SESSION  tensorflow/core/common_runtime/direct_session.cc 也就是单机模式下,本地运行的Session. 
  2. GRPC_SESSION tensorflow/core/distribute_runtime/rpc/grpc_session.cc 分布式模式下基于rpc的session。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值