如何查找Tensorflow operation的实现源码
笔者由于工作原因经常需要查阅tensorflow各个operation的实现,然而有些op实在没法猜到它到底定义在那个文件里,全文搜索op的名称又经常搜出来太多的文件,无法快速筛选。
近日笔者研究了一下tensorflow增加新的op的方式,发现了一个查找op实现的好方法。
一般来说,在tensorflow中增加一个新的op需要两步。(以下内容均参考自tensorflow官方文档)
- 定义这个op的接口,并注册到tensorflow中。
在接口中定义中,需要指定这个op的input, output以及相关的一些attribute。定义op接口需要调用宏REGISTER_OP
,例如:
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/shape_inference.h"
using namespace tensorflow;
REGISTER_OP("ZeroOut")
.Input