在TensorFlow
中,用户是通过运行图来进行模型训练的,而启动图的第一步就是创建一个session
对象。在日常编写Python
代码时,有的直接通过编写sess=tf.Session()
来创建session
,也有的在分布式TensorFlow
中通过ChiefSessionCreator
和WorkerSessionCreator
的create_session()
来创建session
。这里简单说明下,create_session()
实质上对tf.Session()
的封装,只是里面添加了很多其他的功能,后期会对SessionCreator
进行详细的介绍。鉴于前期读者反馈说看不大懂,所以今天,谱哥主要是想带大家来了解下sess=tf.Session()
背后的实现原理,并介绍allocator
在session
创建时在哪里有体现。
TensorFlow
系统分为前端系统和后端系统,前端系统提供编程模型,重点负责图的构造,目前主流编程语言是
Python
;后端系统主要负责图的执行,用C++语言来进行编写;
Swig
作为前端系统和后端系统建立连接的桥梁,使得前端
Python
创建
session
能够触发后端C++进行
session
创建。因此,接下来,将按照前端
Python
层、
Swig
以及后端C++层三个方面来详细说明
sess=tf.Session()
底部实现原理。
1. 前端:Python层在前端系统中,session相关类的继承关系如下所示:
从中可知,
session
分为两种,普通
Session
和交互式
InteractiveSession
。后者自带with上下文管理器,并且在初始化的时候将自身作为默认的
session
,因此适合在
Python
交互式环境下使用。普通
Session
和交互式
InteractiveSession
都继承
BaseSession
,
BaseSession
继承
SessionInterface
。当用户层执行
sess=tf.Session()
时,会依次调用
SessionInterface
、
BaseSession
和
Session
的初始化函数。在
BaseSession
的初始化函数中有以下几行代码:
from tensorflow.python import pywrap_tensorflow as tf_sessionclass BaseSession(SessionInterface): def __init__(self, target='', graph=None, config=None): ...... self._session = None opts = tf_session.TF_NewSessionOptions(target=self._target, config=config) try: # pylint: disable=protected-access self._session = tf_session.TF_NewSession(self._graph._c_graph, opts) # pylint: enable=protected-access finally: tf_session.TF_DeleteSessionOptions(opts)
由上可知,
session是在BaseSession初始化的时候执行tf_session.TF_NewSession()来创建,传入的参数
opts
通过
tf_session.TF_NewSessionOptions
创建,是一个
SessionOptions
结构体,完成
env
、
target
和
config
的简单封装。
target
参数主要用来判断是创建
DirectSession
还是
GrpcSession
。
struct SessionOptions { Env* env; string target; ConfigProto config; SessionOptions();};
tf_session
实质指的是
pywrap_tensorflow.py
模块,该模块内部导入了
pywrap_tensorflow_internal.py
模块。而
pywrap_tensorflow_internal.py
是在系统启动
Swig
的时候通过
tensorflow.i
自动生成的适配文件,因此要想知道
tf_session.TF_NewSession()
内部到底干了啥,需要了解
TensorfFlow
是怎样使用
Swig
。
2. Swig包装器在
TensorFlow
启动
Swig
的时候,会通过
tensorflow.i
生成两个文件: (1) pywrap_tensorflow_internal.py:对接前端python接口调用;
(2) pywrap_tensorflow_internal.cc:对接后端C++接口调用。
pywrap_tensorflow_internal.py
模块在
pywrap_tensorflow.py
模块中被导入的时候,会自动加载
_pywrap_tensorflow_internal.so
的动态链接库,该库包含了整个
TensorFlow
运行时的所有符号。因此,在
pywrap_tensorflow_internal.py
模块中,可以通过
_pywrap_tensorflow_internal
转发,实现
Python
接口到
_pywrap_tensorflow_internal.so
的函数调用。
def TF_NewSession(graph, opts): return _pywrap_tensorflow_internal.TF_NewSession(graph, opts)TF_NewSession = _pywrap_tensorflow_internal.TF_NewSession
pywrap_tensorflow_internal.cc
注册了一个函数符号表,实现Python函数到C函数名的映射。
{ (char *)"TF_NewSession", _wrap_TF_NewSession, METH_VARARGS, NULL},
而
_wrap_TF_NewSession
将调用
c_api.h
对其开放的API接口:
TF_NewSession
,从而进入系统后端C++层。
3. 后端:C++层
3.1 TF_NewSession在
c_api.cc
文件中,
TF_NewSession()
相关定义如下所示:
TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt, TF_Status* status) { Session* session; //创建session status->status = NewSession(opt->options, &session); if (status->status.ok()) { //创建TF_Session对象,实现session和graph的绑定 TF_Session* new_session = new TF_Session(session, graph); if (graph != nullptr) { mutex_lock l(graph->mu); graph->sessions[new_session] = ""; } //返回TF_Session对象 return new_session; } else { DCHECK_EQ(nullptr, session); return nullptr; }}
从上可以看出,
TF_NewSession()
干了两件事儿:(1)创建了
session
对象;(2)创建
TF_Session
对象,实现
session
和
graph
的绑定,并
最终返回TF_Session,而不是session。
TF_Session
的相关定义如下:
struct TF_Session { TF_Session(tensorflow::Session* s, TF_Graph* g); tensorflow::Session* session; TF_Graph* const graph; int last_num_graph_nodes; // If true, TF_SessionRun and similar methods will call // ExtendSessionGraphHelper before running the graph std::atomic extend_before_run;};
3.2 NewSession
那么NewSession()
又是怎么定义的呢?于是追踪到了
session.cc
文件,相关代码如下所示:
Status NewSession(const SessionOptions& options, Session** out_session) { SessionFactory* factory; const Status s = SessionFactory::GetFactory(options, &factory); if (!s.ok()) { *out_session = nullptr; LOG(ERROR) << s; return s; } *out_session = factory->NewSession(options); if (!*out_session) { return errors::Internal("Failed to create session."); } return Status::OK();}
可知,
NewSession
采用了工厂模式,先根据
options
去找出符合要求的工厂
factory
,然后在指定的工厂里创建
Session
。
3.3 SessionFactory::GetFactory 可能大家又会问又是如何去根据
options
找出对应的
factory
?于是,追踪到了
session_factory.cc
文件,在剖析
GetFactory()
之前,需要先来理解
session_factories()
的相关概念,如下所示:
typedef std::unordered_map SessionFactories;SessionFactories* session_factories() { static SessionFactories* factories = new SessionFactories; return factories;}
可知,
session_factories()
其实是创建了一个静态的
SessionFactories
,而这个
SessionFactories
是一个
unordered_map
,实现
string
类型的
runtime_type
到
SessionFactory
指针的映射。那么既然是
unordered_map
,就必然涉及到
key
和
value
的存储,在这里是通过
SessionFactory::Register()
来进行注册的。
SessionFactory::Register()
的相关定义如下所示:
void SessionFactory::Register(const string& runtime_type, SessionFactory* factory) { mutex_lock l(*get_session_factory_lock()); if (!session_factories()->insert({runtime_type, factory}).second) { ...... }}
由上可知,
SessionFactory::Register()
的本质就是将
runtime_type
和
factory
指针插入到
unordered_map
中。紧接着,我追踪到
direct_session.cc
和
grpc_session.cc
文件,发现了相关
session
的注册。
DirectSessionFactory
的注册如下所示:
class DirectSessionRegistrar { public: DirectSessionRegistrar() { SessionFactory::Register("DIRECT_SESSION", new DirectSessionFactory()); }};static DirectSessionRegistrar registrar;
GrpcSessionFactory
的注册如下所示:
class GrpcSessionRegistrar { public: GrpcSessionRegistrar() { SessionFactory::Register("GRPC_SESSION", new GrpcSessionFactory()); }};static GrpcSessionRegistrar registrar;
现在来看看
SessionFactory::GetFactory()
核心代码,如下所示:
Status SessionFactory::GetFactory(const SessionOptions& options, SessionFactory** out_factory) { std::vector<:pair sessionfactory>> candidate_factories; for (const auto& session_factory : *session_factories()) { if (session_factory.second->AcceptsOptions(options)) { candidate_factories.push_back(session_factory); } } if (candidate_factories.size() == 1) { *out_factory = candidate_factories[0].second; return Status::OK(); } else if (candidate_factories.size() > 1) { //报错 }}
从
GetFactory()
代码中可知,其本质就是遍历
session_factories()
中的
unordered_map
,然后通过
unordered_map
中的
SessionFactory
(如
DirectSessionFactory
和
GrpcSessionFactory
)是否
AcceptsOptions
来进行选择,并且硬性要求有且仅有一个
factory
满足要求,否则报错。不同
SessionFactory
的
AcceptsOptions()
的定义如下:
DirectSessionFactory:
bool AcceptsOptions(const SessionOptions& options) override { return options.target.empty(); }
从上可知,若
options.target
为空,则应选择
DirectSessionFactory
,用于本地训练。
GrpcSessionFactory:
const char* const kSchemePrefix = "grpc://";bool AcceptsOptions(const SessionOptions& options) override { return str_util::StartsWith(options.target, kSchemePrefix); }
从上可知,若
options.target
是以
grpc://
开头的,则应选择
GrpcSessionFactory
,用于分布式
TensorFlow
。
3.4 factory->NewSession延续3.2节所讲,根据3.3节
SessionFactory::GetFactory
返回值的
SessionFactory
类型,去调用对应
SessionFactory
的
NewSession
接口。这里以
DirectSessionFactory::NewSession
为例,代码如下:
Session* NewSession(const SessionOptions& options) override { // Must do this before the CPU allocator is created. if (options.config.graph_options().build_cost_model() > 0) { EnableCPUAllocatorFullStats(true); } std::vector devices; const Status s = DeviceFactory::AddDevices( options, "/job:localhost/replica:0/task:0", &devices); if (!s.ok()) { LOG(ERROR) << s; return nullptr; } DirectSession* session = new DirectSession(options, new DeviceMgr(devices), this); { mutex_lock l(sessions_lock_); sessions_.push_back(session); } return session; }
从上可知,
DirectSessionFactory::NewSession()
不单单只是创建
DirectSession
,还要完成相关的
devices
收集。相关设备收集通过调用
DeviceFactory::AddDevices()
来完成,相关代码在
device_factory.cc
中,如下所示:
Status DeviceFactory::AddDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) { //先获取CPU对应的设备工厂cpu_factory auto cpu_factory = GetFactory("CPU"); //创建设备并记录保存到devices TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(options, name_prefix, devices)); ... //遍历device_factories(),创建设备集,包括GPU for (auto& p : device_factories()) { auto factory = p.second.factory.get(); if (factory != cpu_factory) { TF_RETURN_IF_ERROR(factory->CreateDevices(options, name_prefix, devices)); } } return Status::OK();}
从上可知,DeviceFactory::AddDevices()
也是采用工厂模式,主要完成的是遍历device_factories()
,然后调用每个factory
中的CreateDevices
接口,创建设备并把相应指针存储到devices vector
中。在此,有几个接口函数需要说明下:
DeviceFactory* DeviceFactory::GetFactory(const string& device_type) { //根据device_type查找对应的DeviceFactory auto it = device_factories().find(device_type); if (it == device_factories().end()) { return nullptr; } return it->second.factory.get();}
DeviceFactory::GetFactory()
主要是通过输入参数
device_type
在
device_factories()
中查找对应的设备工厂
DeviceFactory
。device_factories()的本质类似于session_factories(),函数里创建了一个静态的unordered_map,表示device_type到FactoryItem的映射。FactoryItem是个结构体,包括factory指针和相应的优先级。
struct FactoryItem { std::unique_ptr factory; int priority;};std::unordered_map& device_factories() { static std::unordered_map* factories = new std::unordered_map; return *factories;}
跟SessionFactory::Register()类似,既然涉及到对unordered_map的读取,那么肯定存在对key和value的存储操作。该操作主要是通过
DeviceFactory::Register()
接口来完成相关
DeviceFactory
的注册。在TensorFlow中,专门为此进行了宏定义,如下所示:
#define REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, ...)INTERNAL_REGISTER_LOCAL_DEVICE_FACTORY(device_type, device_factory, __COUNTER__, ##__VA_ARGS__)
以下是相关DeviceFactory注册的部分代码,分别在threadpool_device_factory.cc和gpu_device_factory.cc文件中。从这也可以看出,
GPUDeviceFactory
的优先级要明显高于
ThreadPoolDeviceFactory
。
REGISTER_LOCAL_DEVICE_FACTORY("CPU", ThreadPoolDeviceFactory, 60);REGISTER_LOCAL_DEVICE_FACTORY("CPU", GPUCompatibleCPUDeviceFactory, 70);REGISTER_LOCAL_DEVICE_FACTORY("GPU", GPUDeviceFactory, 210);
CreateDevices 遍历
device_factories()
里
unordered_map
的时候,都会让每个
DeviceFactory
调用设备创建
CreateDevices()
,并存储到
std::vector* devices
中。下面以ThreadPoolDeviceFactory::CreateDevices为例来介绍其具体细节。
Status ThreadPoolDeviceFactory::CreateDevices(const SessionOptions& options, const string& name_prefix, std::vector* devices) override { int n = 1; auto iter = options.config.device_count().find("CPU"); if (iter != options.config.device_count().end()) { n = iter->second; } for (int i = 0; i < n; i++) { string name = strings::StrCat(name_prefix, "/device:CPU:", i); devices->push_back(new ThreadPoolDevice( options, name, Bytes(256 << 20), DeviceLocality(), cpu_allocator())); } return Status::OK(); }
从第9行可以看出,
ThreadPoolDeviceFactory::CreateDevices
主要是创建了一个
ThreadPoolDevice
,
ThreadPoolDevice
中存有一个
allocator
用来分配和释放内存。该
allocator
由
cpu_allocator()
来获取。
cpu_allocator()
相关定义在
allocator.cc
文件中,如下所示:
Allocator* cpu_allocator() { static Allocator* cpu_alloc = AllocatorRegistry::Global()->GetAllocator(); if (cpu_allocator_collect_full_stats && !cpu_alloc->TracksAllocationSizes()) { cpu_alloc = new TrackingAllocator(cpu_alloc, true); } return cpu_alloc;}
不知道大家现在对第2行的接口有没有丝丝熟悉感?对的,在Allocator(基础篇)已经对
AllocatorRegistry
相关接口进行了说明。从第2行可看出,每次执行
AllocatorRegistry::Global()->GetAllocator()
都会返回
AllocatorRegistry
当前优先级最高的
allocator
。如果该
allocator
想收集状态但是
TracksAllocationSizes()
又为false,那么就可以对该
allocator
进行封装,在此基础上创建一个
TrackingAllocator
即可进行记录追踪。当然,如果想知道底层是在哪里使用了
BFCAllocator
,则推荐阅读
GPUDeviceFactory::CreateDevices
,这里不再做过多说明。
4. 总结本篇根据前端
Python
层、
Swig
以及后端
C++
层三个方面来详细说明
sess=tf.Session()
底部实现原理。前端
Python
层介绍了
Session
和
BaseSession
等的概念和相互联系;
Swig
主要完成将
Python
层的
session
创建转发到
C++
层的
session
创建;后端
C++
层
session
创建根据
SessionOptions
找到相应的
SessionFactory
来执行
NewSession
操作。而
NewSession
函数不仅要完成
session
创建,还要根据
DeviceFactory
完成设备的创建并收集,这里自然离不开各种
allocator
。