Layer的设计思路
Layer的抽象
如果将深度学习中的所有层分为两类, 那么肯定是"带权重"的层和"不带权重"的层。
基于层的共性,我们定义了一个Layer的基类,提供了一些基本接口,并可以通过继承和多态机制实现不同类型的Layer。
具体来说,该类包括以下几个成员函数:
-
构造函数
Layer(std::string layer_name)
,用于创建一个Layer对象并设置该层的名称。 -
virtual ~Layer() = default
,虚析构函数,在派生类中可以通过override关键字重新定义。 -
virtual InferStatus Forward(const std::vector<std::shared_ptr<Tensor<float>>> &inputs, std::vector<std::shared_ptr<Tensor<float>>> &outputs)
,前向传播函数,将输入tensor作为参数,计算输出tensor。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &weights() const
, 返回当前层的权重数组。 -
virtual const std::vector<std::shared_ptr<Tensor<float>>> &bias() const
, 返回当前层的偏置数组。 -
virtual void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights)
,设置当前层的权重数组。 -
virtual void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias)
,设置当前层的偏置数组。 -
virtual void set_weights(const std::vector<float> &weights)
,将权重数据类型转换为shared_ptr
后调用上述函数。 -
virtual void set_bias(const std::vector<float> &bias)
,将偏置数据类型转换为shared_ptr
后调用上述函数。 -
virtual const std::string &layer_name() const
,返回当前层的名称。
而成员变量只有一个,即
std::string layer_name_
,Layer的名称
为什么定义成虚函数
在神经网络中,不同的层具有不同的结构和运算方式,因此需要不同的函数来实现它们。使用虚函数的方法可以将这些不同的函数封装到一个基类中,并通过多态机制来实现不同类型的层的动态绑定。
具体来说,当使用基类指针或引用调用虚函数时,程序会根据对象的动态类型(即实际指向的派生类类型)来选择相应的函数实现。这就使得不同类型的层可以通过共同的接口进行调用,从而提高了代码的可维护性和可扩展性。
此外,使用虚函数还可以方便地定义抽象类,即只声明虚函数但不提供实现的类。这可以为子类提供一个规范化的接口,要求其必须重写某些接口以满足特定的需求。这种机制可以有效避免在大型工程中出现微小的差错而导致底层实现不符合最终需求的问题。
带权重Layer的实现
我们把Layer基类来表示不带参数的Layer,并且通过继承该Layer基类的方式来定义了一个带参数的层ParamLayer子类,在ParamLayer中定义了成员变量bias和weights。
ParamLayer是具有可调参数的神经网络层实现,包括初始化权重和偏置的函数、重载读写权重和偏置的函数,以及保存权重和偏置的成员变量。
具体来说,该类包括以下几个成员函数和成员变量:
-
构造函数
ParamLayer(const std::string &layer_name)
,用于创建一个ParamLayer对象并设置该层的名称。 -
void InitWeightParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width)
,用于初始化权重参数。 -
void InitBiasParam(const uint32_t param_count, const uint32_t param_channel, const uint32_t param_height, const uint32_t param_width)
,用于初始化偏置参数。 -
const std::vector<std::shared_ptr<Tensor<float>>> &weights() const override
,重载虚函数weights()
,返回保存权重参数的成员变量weights_
。 -
const std::vector<std::shared_ptr<Tensor<float>>> &bias() const override
,重载虚函数bias()
,返回保存偏置参数的成员变量bias_
。 -
void set_weights(const std::vector<float> &weights) override
,重载虚函数set_weights()
,将权重数据类型转换为shared_ptr
后存储在成员变量weights_中。 -
void set_bias(const std::vector<float> &bias) override
,重载虚函数set_bias()
,将偏置数据类型转换为shared_ptr
后存储在成员变量bias_
中。 -
void set_weights(const std::vector<std::shared_ptr<Tensor<float>>> &weights) override
,重载虚函数set_weights()
,将参数复制到成员变量weights_
中。 -
void set_bias(const std::vector<std::shared_ptr<Tensor<float>>> &bias) override
,重载虚函数set_bias()
,将参数复制到成员变量bias_
中。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> weights_
,保存ParamLayer的权重参数。 -
成员变量
std::vector<std::shared_ptr<Tensor<float>>> bias_
,保存ParamLayer的偏置参数。
ParamLayer通过继承Layer类实现了一些共同接口,并在此基础上扩展了更多函数和成员,可以方便地实现带有参数的神经网络层。
Layer的注册机制
为了实现注册和创建神经网络层,并在运行时动态地生成不同类型的神经网络层,定义了两个类:LayerRegisterer和LayerRegistererWrapper。
具体来说,LayerRegisterer类提供了三个静态函数和一个静态成员变量:
-
typedef ParseParameterAttrStatus (*Creator)(const std::shared_ptr<RuntimeOperator> &op, std::shared_ptr<Layer> &layer)
:定义了一个函数指针类型Creator,用于指向具体神经网络层的函数。 -
typedef std::map<std::string, Creator> CreateRegistry
:定义了一个映射类型CreateRegistry,用于保存层类型和对应创建函数的映射关系。 -
static void RegisterCreator(const std::string &layer_type, const Creator &creator)
:将层类型和创建函数的映射关系注册到CreateRegistry中。 -
static std::shared_ptr<Layer> CreateLayer(const std::shared_ptr<RuntimeOperator> &op)
:根据输入的op对象创建相应的神经网络层。 -
static CreateRegistry &Registry()
:返回当前已经注册的所有层类型和创建函数的映射关系。
RuntimeOperator是计算图的某个计算节点,里面保存了计算节点所需的参数等信息,具体介绍请看3.Graph.md
。
而LayerRegistererWrapper类则提供了一个构造函数,用于将某一种类型的神经网络层和其创建函数注册到LayerRegisterer中,如下所示。
class LayerRegistererWrapper {
public:
LayerRegistererWrapper(const std::string &layer_type, const LayerRegisterer::Creator &creator) {
LayerRegisterer::RegisterCreator(layer_type, creator);
}
};
在LayerRegisterer类中,通过维护一个键值对(<std::string, Creator>
)CreateRegistry
,管理Layer注册表,在注册和查找Layer时都要先检查一下是否注册,如果未注册输出错误信息。
为什么要把成员函数定义为静态的
静态函数与类相关联,而不是与类的对象相关。因此,静态函数可以在没有创建类的实例的情况下调用,从而方便地提供一些辅助函数或管理函数,例如工厂方法、单例等。
LayerRegisterer和LayerRegistererWrapper中定义的所有函数都是静态的,主要原因是这些函数需要全局地维护层类型和创建函数的映射关系,并控制新层类型的注册和创建过程。使用静态函数可以使得这些功能在整个程序中被共享和访问,同时避免了由于对象实例的含糊不清而导致的编码错误。
另外需要注意的是,静态函数可以直接使用静态成员变量,不需要通过对象来访问,这使得这些静态函数可以更容易地协同工作,并兼顾了效率和灵活性。
阅读的代码
- include
- layer
- abstract
- layer_factory.hpp
- layer.hpp
- param_layer.hpp
- abstract
- layer
- source
- layer
- abstract
- layer.cpp
- layer_factory.cpp
- param_layer.cpp
- abstract
- layer