PyTorch框架AMP(Automatic Mixed Precision)模块的源码早在一年之前便已初探过,然而,最近有人向我咨询某些源码问题时,一些自认为当时已了然于胸的细节,如今已回忆不起来。果然,好记性还是不如烂笔头,于是趁着这次重新梳理,将源码中重要的细节记录下来(针对PyTorch1.9版本源码),同时也希望能够帮助到对这部分源码有疑惑的朋友。
PyTorch AMP主要的功能分为Autocasting和Gradient Scaling,这两部分的算法原理以后有空再介绍。本文主要分析Autocasting功能的技术细节,Gradient Scaling不作探究,主要因为Autocasting功能涉及的部分C++代码比较有趣,而Gradient Scaling为Python实现,相对简单。
1、什么是Autocasting ?
autocast,顾名思义,即自动进行cast操作。在PyTorch中,指无需人为干预的情况下,自动对算子的输入进行cast操作。使用方法可参考PyTorch AMP autocast官方文档 :
# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)
for input, target in data:
optimizer.zero_grad()
# Enables autocasting for the forward pass (model + loss)
with autocast():
output = model(input)
loss = loss_fn(output, target)
# Exits the context manager before backward()
loss.backward()
optimizer.step()
autocast功能的api作为python上下文管理器使用,在上下文作用域内的算子,PyTorch会根据内部的策略,选择性对算子的输入进行操作(绝大部分是进行cast,部分算子将目标dtype信息作为参数传给算子),从而达到改变算子实际计算精度的功能。
2、Autocasting内部原理
关于autocast功能的核心代码位于https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/autocast_mode.cpp,这部分代码数量虽然不多,但对于PyTorch新手或者C++新手来说,还是存在一定障碍。为了让新手们能够快速了解这部分原理,下面我会从功能实现者的角度来剖析这部分源码。
autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:
- 在PyTorch算子调用栈上某一层插入处理函数
- 在处理函数中对算子的输入进行必要操作
2.1 利用Autocast dispatchkey注册处理函数
在PyTorch的基础架构中,dispatcher可以说是其最核心的部分,简单来说,dispatcher是一个根据输入的属性决定调用哪个函数的系统。关于dispatcher的详细说明,建议阅读Edward Yang的这篇优秀文章http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/ ,
利用该文章中提到的 Operator Registration API 即可往dispatch table中插入函数指针。PyTorch已经定义了Autocast dispatch key,直接利用TORCH_LIBRARY_IMPL(aten, Autocast, m)即可达到插入处理函数的目的。因此,autocast_mode.cpp中后半部分源码都是在利用TORCH_LIBRARY_IMPL接口为多个算子插入处理函数:
TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16)
... ...
2.2 在处理函数中对算子进行autocast
根据Autocast key的优先级,利用*TORCH_LIBRARY_IMPL(aten, Autocast, m)*注册的处理函数会在进入对应算子kernel前被调用,类似python的装饰器一样,处理函数会拿到接下来要运行的函数已经函数的所有参数,因此,处理函数的任务简单来说就是:对算子参数进行进行过滤(根据autocast策略选择性进行类型转换),再往后继续执行。
PyTorch的autocast策略如下:
// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
fp16 = 0, // Cast all inputs to at::kHalf before running the op.
fp32, // Cast all inputs to at::kFloat before running the op.
fp32_set_opt_dtype, // Treats functions (like softmax) that
// 1. we'd like to run in fp32 and
// 2. have a c10::optional<ScalarType> arg that controls the output type.
// fp32_set_opt_dtype wrappers' policy is: if the output type is already set,
// don't touch it, otherwise, set it to at::kFloat.
fp32_append_dtype, // Treats functions (like norm) that
// 1. we'd like to run in fp32 and
// 2. have some overloads that accept an output type and other overloads that don't.
// fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
// The wrapper policy is: append at::kFloat to the args, and redispatch to the
// type-aware overload.
promote, // Run in the widest dtype among several args.
};
根据上述策略,将需要进行autocast的算子划分了如下名单:
- fp16 :该名单内的算子,在算子执行前将所有输入cast成fp16
- fp32:该名单内的算子,在算子执行前将所有输入cast成fp32
- fp32_set_opt_dtype:该名单内的算子,如果有指定output的dtype,那么autocast什么都不做;否则将output指定为fp32
- fp32_append_dtype:该名单内的算子,在其参数列表最后增加fp32的dtype参数,然后分发到重载版算子
- promote:该名单内的算子,将采用参数列表中精度最高的dtype进行运算
观察该名单可以发现,fp32、fp32_set_opt_dtype和fp32_append_dtype这三种类型的算子都是要进行fp32计算,那为何要分成三种策略呢?个人认为,之所以划分成三种策略,是由于不同算子在框架中的实现方式差异导致。在fp32名单中的算子,需要对输入强制cast才能指定计算类型,而fp32_set_opt_dtype和fp32_append_dtype名单内的算子,计算类型大多和输出类型有关,所以增加了参数控制输出类型,即控制计算类型。
名单中的不同策略,对应着在处理函数中对算子的参数做不同逻辑的处理。autocast_mode.cpp前半部分代码正是针对几种autocast策略的处理函数的实现:
... ...
/********************************************************************************************************
Templates to provide wrapper functions
I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to extract args and return type.
(see also https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
This strategy uses an exterior "WrapFunction" that extracts arguments on behalf of
(in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/
// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy
template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};
// CastPolicy::fp16
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kHalf, args)...);
}
};
// CastPolicy::fp32
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
return (*F)(cached_cast(at::kFloat, args)...);
}
};
// CastPolicy::fp32_set_opt_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
if (firstarg_is_eligible(args...)) {
return (*F)(set_opt_dtype(at::kFloat, args)...);
} else {
// If ineligible, calls F with unaltered args. Does not set opt dtype, because setting
// opt dtype explicitly may interfere with internal implicit promotion decisions.
return (*F)(args...);
}
}
};
// CastPolicy::fp32_append_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
return (*F)(args..., out_type);
}
};
// CastPolicy::promote
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
static Ret call(Args... args) {
c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
auto to_type = promote_type(at::kHalf, args...);
return (*F)(cached_cast(to_type, args)...);
}
};
// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
class Registered, // The signature for which we're registering. The dispatcher's calling code invokes our
// registered functions with arguments matching Registered, so we register
// WrapFunction_::call methods with a matching signature to properly field those arguments.
// guts::function_traits below extracts return_type and parameter_types from Registered,
// which WrapFunction_ templates above use to declare their call methods.
class Redispatch, // The signature for the function we're redispatching to. In most cases this is the same
// as Registered, but for some ops (for example, ops where we append a dtype) it's useful
// to redispatch to a function with a different signature.
Redispatch* F> // The actual function we're redispatching to.
struct WrapFunction final {
using type = WrapFunction_<policy,
Redispatch,
F,
typename guts::function_traits<Registered>::return_type,
typename guts::function_traits<Registered>::parameter_types>;
};
... ...
这段代码虽然不长,但值得学习和借鉴的地方却很多。下面尝试从代码贡献者的角度,构造一个同样场景的需求,简化并分析这段代码。
假设有这样一个需求:现有如下三个函数,需要编写一个包装器wrapper,在不改变原有函数实现的基础上,对原有函数的输入进行非负判断和绝对值操作。
int add_one(int a) {
int res = a + 1;
return res;
}
float add_two(float a, float b) {
float res = a + b + 1;
return res;
}
long add_three(long a, long b, long c) {
long res = a + b + c + 1;
return res;
}
一个常见的思路是使用模板函数设计wrapper,然后在使用的时候将待包装的函数地址和参数传给wrapper。
在实现wrapper的过程中,我相信会有人会写出类似这样的代码:
... ...
template<void* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
// TODO: wrapper func can do something for args here ...
return (*f)(args...);
}
int main() {
wrapper<&add_one, int>(-10);
return 0;
}
看上去十分简洁,其实暗藏两个小Tips:
- void* 不能作为non-type template的模版参数
- void* 未强转到具体类型之前不能解引用
关于void解引用的限制,相信大家都在书中看到过,至于背后的原因,不知道又有几人知晓。void指向的内容无非两种,object或function,对于object,如果不知道具体类型,便无法知其内存布局,因此该限制不难理解,不过对于function,生成的汇编代码只要知道其地址就可以直接call,为什么还要知道其函数类型呢?关于这个疑问我在网上没找到答案,不过思考再三后我有一个自己的理解:在Calling Convention的规定中,函数参数的传递是由caller完成的,如果不知道函数的类型,也就不知道其参数的大小和个数,因此无法完成参数的构造和传递。
既然不能使用void*接收不同类型的函数指针,那就必须提供完整的函数类型,因此wrapper可以这么写:
... ...
template<typename Ret, typename... Args>
Ret wrapper(void* f, Args... args) {
// TODO: wrapper func can do something for args here ...
typedef Ret (*ft)(Args...);
return ((ft)f)(args...);
}
int main() {
wrapper<int, int>((void*)&add_one, -10);
return 0;
}
这种写法虽然可以达到目的,但不够优美,函数指针理论上在编译期就可以得到,无需在运行时传入,针对这一点,代码可以进一步优化:
... ...
template<typename FT, FT* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
// TODO: wrapper func can do something for args here ...
return f(args...);
}
int main() {
wrapper<int(int), &add_one, int>(-10);
return 0;
}
经过优化,这次函数指针在编译期就可以获得。不过,上述代码仍然存在可优化的空间:第一个模板参数FT已经包含了函数的入参和返回值类型,需要用户再次提供返回值类型稍显累赘,理论上完全可以通过Type Traits提取返回值的类型:
template<typename FT>
struct func_traits{};
//通过类模板部分特化,利用编译器推导出函数的返回值类型
template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
using func_type = Ret(Args...);
using ret_type = Ret;
// you can get other traits here
};
template<typename FT, FT* f, typename... Args>
//默认情况下,C++ 语言假定通过作用域运算符访问的名字不是类型。
//因此,如果我们希望使用一个模板类型参数的类型成员,就必须显式告诉编译器该名字是一个类型。
//我们通过使用关键字 typename 来实现这一点。
typename func_traits<FT>::ret_type wrapper(Args... args) {
// TODO: wrapper func can do something for args here ...
return f(args...);
}
int main() {
wrapper<int(int), &add_one>(-10);
return 0;
}
func_traits的实现利用类模板的偏特化推导出函数的返回值类型,避免了用户再次传入Ret类型的限制。至此,只剩下对函数的入参做非负判断和绝对值操作了:
... ...
template<typename T>
T my_abs(T t) {
if(t < 0){
return -t;
}
else {
return t;
}
}
template<typename FT>
struct func_traits{};
template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
using func_type = Ret(Args...);
using ret_type = Ret;
// you can get other traits here
};
template<typename FT, FT* f, typename... Args>
typename func_traits<FT>::ret_type wrapper(Args... args) {
return f(my_abs(args)...); //模板参数包unpack
}
int main() {
wrapper<int(int), &add_one>(-10);
wrapper<float(float, float), &add_two>(-10, -10);
wrapper<long(long, long, long), &add_three>(-10, -10, -10);
return 0;
}
值得注意的是,"my_abs(args)…"是模板参数unpack语法,展开应该是“my_abs(arg0), … … , my_abs(argn)”的形式。
到此为止,我们已经完成了wrapper包装器的所有功能。麻雀虽小,五脏俱全,autocast_mode.cpp中wrapper代码的所有核心功能都已包含,相信此时再看源码心中应该了无疑惑了。