Tensorflow中设计模式-反射模式
作为一个优秀的深度学习框架,tensorflow支持各种各样的设备和op,而且支持用户自定义op。用户在自定义一个op后,这个op如何注册进tensorflow框架,用户在使用这个op时,tensorflow是如何知道这个op的呢?今天可以聊聊tensorflow底层代码中广泛使用的反射设计模式。
在开始聊tensorflow之前,我们先了解一下什么是反射模式。在很多时候,我们创建要创建的对象,在代码编译前都已经写好,但是在有的场合下,只有在程序运行的时候,才知道要创建对象的类型,很多时候需要配合一个工厂模式,写一个方法,根据传入的对象的类型信息,去new一个新的对象。当需要新加一个类型时,对工厂类进行修改,在工厂方法类中新增一行,然后重新编译,不符合开闭原则。有没有一种方式,当增加一种新类型时,该方式能够自动将类信息注册进系统内呢,答案就是反射模式。在java中,class.forname使得java天生支持反射模式。但是在c++中,没有类似的实现,需要用户自己去实现。本文借tensorflow相关代码来了解一下反射模式,顺便也学习一下tensorflow中相关模块的底层实现,本文以tensorflow2.0版本为例。
首先来看op的创建,我们知道在定义op时,至少会用到两个宏:
REGISTER_OP和REGISTER_KERNEL_BUILDER。
REGISTER_OP主要用来定义op的描述,例如输入参数,输出参数,属性参数等。先看源码:
可以看出宏先生成一个OpDefBuilderWrapper对象,通过将该对象作为参数调用OpDefBuilderReceiver类构造函数,最终生成一个静态OpDefBuilderReceiver对象。重点查看OpDefBuilderReceiver的构造函数。
可以看出构造函数注册了一个lamda表达式,该表达式最终返回一个OpDefBuilder对象,这个对象表示一个op的描述。
重点看一下287行代码:
245行定义了一个静态对象,该对象提供了op描述注册和建立相关方法:
重点是Register注册方法和LookUp查询方法。
Register主要注册生成op描述的方法。Lookup返回一个OpDef,即是一个op的描述。
可以看出当注册op描述时,只要使用REGISTER_OP宏,就能自动定义一个op的描述。而无需修改已有的代码。
再来看REGISTER_KERNEL_BUILDER宏
1462行,可以看出该宏主要生成一个OpKernelRegistrar对象,其目的是调用该类的构造函数。
1559行调用InitInternal函数。
1206行表示将生成OpKernel对象的工厂emplace到1074行的registry对象里面。
当用户顶一个opkernal时,只需调用REGISTER_KERNEL_BUILDER宏,就能自动将生成自己对象的工厂注册进框架。当用户需要生成opkernal时,
1472行获取工厂,赋给registeration.
在1515行创建一个opkernal。
在tensorflow中,还有类似设计,例如device。Tensorflow支持各种设备,如果用户需要支持一种新设备,只需调用REGISTER_LOCAL_DEVICE_FACTORY宏即可,该宏实现的方式基本和op的宏类似。