MXNet中,gluon.Block类和gluon.HybridBlock类,和Pytorch中的nn.Module类一样,我们通过继承Block类和HybridBlock类可以很灵活的搭建我们自己的网络模型,这里总结一下HybridBlock类使用过程中的一些注意点。
HybridBlock类和Block类的区别
HybridBlock类继承至Block类,所以HybridBlock类有Block类的全部方法和属性。HybridBlock同时支持符号式编程和命令式编程,HybridBlock类可以调用hybridize()方法,从而可以从命令式变为符号式,从而将动态图转化为静态图,提高模型的计算性能和移植性。下面是两者的比较:
| HybridBlock类 | Block类 | |
|---|---|---|
| 重写方法 | __init__()、hybrid_forward(self, F, x, *args, **kwargs) |
__init__()、forwad(self,x,*args) |
| 是否支持符号式 | 是 | 否 |
| 支持输入参数 | 位置式参数、关键字参数 | 只支持位置式参数 |
| 是否支持导出符号模型 | 是 | 否 |
可以看出HybridBlock类除了多支持符号式编程外,和Block基本没什么区别,但是注意到支持输入参数那一栏,hybrid_forward函数还支持输入关键字参数,这点也和Block不一样,下面详细分析一下hybrid_forward的调用过程。
hybrid_forward()分析
当我们构建一个HybridBlock类后,需要重写其|__init__()、hybrid_forward()方法,而我们在源码中可以看到,当一个HybridBlock类进行forward操作时,其流程如下:
__call__()-------->forward()-------->hybrid_forward()
可以看出HybridBlock类是通过forward()方法中来调用hybrid_forward()。由于HybridBlock类中的forward()方法已经被重写过了,所以我们只需要重写hybrid_forward()就可以了,其中forward()函数如下:
def forward(self, x, *args):
"""Defines the forward computation. Arguments can be either
:py:class:`NDArray` or :py:class:`Symbol`."""
if isinstance(x, NDArray):
with x.context as ctx:
if self._active:
本文探讨MXNet中的HybridBlock类与Block类的区别,重点解析HybridBlock的hybrid_forward()函数。HybridBlock支持符号式和命令式编程,通过hybridize()能提升模型性能。hybrid_forward()调用过程涉及参数注册和传递,所有在该方法中注册的参数会作为关键字参数传递给forward(),且在前向运算时自动同步到输入设备。利用这一特性,可以更灵活地构建网络模型。
最低0.47元/天 解锁文章
5591

被折叠的 条评论
为什么被折叠?



