PyTorch AMP autocast Source Code Analysis

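I first explored the source code of PyTorch's AMP (Automatic Mixed Precision) module more than a year ago. Recently, though, when someone asked me about parts of that code, I found I could no longer recall details I had once thought I understood thoroughly. As the saying goes, the faintest ink beats the best memory. So while going through the code again, I am writing down the important details (based on the PyTorch 1.9 source), in the hope that this also helps anyone puzzled by this part of the codebase.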

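PyTorch AMP provides two main features: Autocasting and Gradient Scaling. I will cover the algorithms behind them another time. This article analyzes only the technical details of Autocasting and does not examine Gradient Scaling, mainly because the C++ code involved in Autocasting is more interesting, whereas Gradient Scaling is implemented in Python and is relatively simple.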

1. What is Autocasting?

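As the name suggests, autocast means casting automatically. In PyTorch, it refers to automatically casting an operator's inputs without manual intervention. For usage, see the official PyTorch AMP autocast documentation: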

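import torch.optim as optim
from torch.cuda.amp import autocast

# Creates model and optimizer in default precision
# (Net, data and loss_fn are placeholders defined elsewhere)
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()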
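
2. How Autocasting Works

Implementing autocast comes down to two things:

  • inserting a handler function at some layer of PyTorch's operator call stack
  • performing the necessary operations on the operator's inputs inside that handler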

2.1 Registering handler functions via the Autocast dispatch key

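The dispatcher is arguably the most central part of PyTorch's infrastructure. Put simply, it is a system that decides which function to call based on properties of the inputs. For a detailed explanation of the dispatcher, I recommend this excellent article by Edward Yang.
The Operator Registration API described in that article lets you insert function pointers into the dispatch table. PyTorch already defines an Autocast dispatch key, so handler functions can be inserted simply with TORCH_LIBRARY_IMPL(aten, Autocast, m). Accordingly, the second half of autocast_mode.cpp consists of TORCH_LIBRARY_IMPL calls that register handlers for many operators: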

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
  KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
  KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16)
  ... ...
}
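To make the dispatch idea concrete, here is a heavily simplified toy model written for this article; none of these names are PyTorch APIs. It assumes that enabling autocast puts the Autocast key into a thread-local "included" set, that Autocast has a higher priority than the backend key, and that the handler excludes Autocast before redispatching (the role ExcludeDispatchKeyGuard plays in the autocast_mode.cpp wrappers). Everything is constexpr so the behaviour can be checked at compile time.

```cpp
#include <cstdint>

// Toy model only: a key set is a bitmask, and a higher bit means higher priority.
enum class Key : std::uint8_t { CPU = 0, CUDA = 1, Autocast = 2 };
using KeySet = std::uint32_t;

constexpr KeySet bit(Key k) { return KeySet{1} << static_cast<unsigned>(k); }

// Return the highest-priority key present in ks (Autocast is the top key here).
constexpr Key highest(KeySet ks) {
  for (int i = 2; i >= 0; --i) {
    if (ks & (KeySet{1} << i)) return static_cast<Key>(i);
  }
  return Key::CPU;  // fallback when the set is empty
}

enum class DType { Half, Float };
struct Tensor { KeySet keys; DType dtype; };  // keys carried by the tensor itself
struct Result { Key kernel; DType dtype; };   // which kernel ran, and on which dtype

constexpr Result dispatch_mm(Tensor t, KeySet included, KeySet excluded);

// Backend kernel: just reports what it received.
constexpr Result cuda_mm(Tensor t) { return {Key::CUDA, t.dtype}; }

// Autocast handler: cast the input, then redispatch with Autocast excluded so the
// next lookup skips this handler and reaches the backend kernel.
constexpr Result autocast_mm(Tensor t, KeySet included, KeySet excluded) {
  Tensor casted{t.keys, DType::Half};
  return dispatch_mm(casted, included, excluded | bit(Key::Autocast));
}

// Dispatcher: compute the effective key set, then call the matching entry.
constexpr Result dispatch_mm(Tensor t, KeySet included, KeySet excluded) {
  switch (highest((t.keys | included) & ~excluded)) {
    case Key::Autocast: return autocast_mm(t, included, excluded);
    case Key::CUDA:     return cuda_mm(t);
    default:            return {Key::CPU, t.dtype};
  }
}
```

Without the exclusion step, autocast_mm would find the Autocast key again and recurse forever, which is why every wrapper in autocast_mode.cpp creates the guard before calling the next function.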

2.2 Applying autocast to the operator inside the handler

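Because of the Autocast key's priority, a handler registered with TORCH_LIBRARY_IMPL(aten, Autocast, m) is called before the corresponding operator kernel runs. Much like a Python decorator, the handler gets hold of the function that is about to run together with all of its arguments. Its job, put simply, is to filter the operator's arguments (selectively casting them according to the autocast policy) and then continue execution.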
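The "filter, then continue" step can be sketched in a few lines of standalone C++. This is a toy written for this article, not PyTorch code: Tensor and DType are hypothetical stand-ins, and maybe_cast only plays the role of cached_cast (the real function also checks whether a tensor is eligible for casting and caches the casted result). Tensor arguments are cast, every other argument passes through unchanged, and a pack expansion forwards everything to F.

```cpp
// Toy stand-ins (hypothetical, not PyTorch types).
enum class DType { Half, Float };
struct Tensor { DType dtype; };

// Tensor arguments are "cast" (in this toy only the dtype tag changes) ...
constexpr Tensor maybe_cast(DType to, Tensor /*t*/) { return Tensor{to}; }

// ... while any other argument (scalars, flags, ...) is returned unchanged.
template <class T>
constexpr T maybe_cast(DType /*to*/, T other) { return other; }

// Toy operator: reports the dtype its tensor input arrived with.
constexpr DType toy_op(Tensor a, double /*alpha*/) { return a.dtype; }

// fp16-style handler: filter every argument, then continue on to F.
template <class Sig, Sig* F, class... Args>
constexpr auto fp16_wrapper(Args... args) {
  return F(maybe_cast(DType::Half, args)...);
}
```

For a Tensor argument the non-template overload of maybe_cast is preferred, so only tensors are touched; for example, fp16_wrapper<decltype(toy_op), &toy_op>(Tensor{DType::Float}, 2.0) reaches toy_op with a Half tensor and an unchanged 2.0.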

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      //   1. we'd like to run in fp32 and
                      //   2. have a c10::optional<ScalarType> arg that controls the output type.
                      // fp32_set_opt_dtype wrappers' policy is:  if the output type is already set,
                      // don't touch it, otherwise, set it to at::kFloat.
  fp32_append_dtype, // Treats functions (like norm) that
                     //   1. we'd like to run in fp32 and
                     //   2. have some overloads that accept an output type and other overloads that don't.
                     // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
                     // The wrapper policy is:  append at::kFloat to the args, and redispatch to the
                     // type-aware overload.
  promote, // Run in the widest dtype among several args.
};


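  • fp16: for operators on this list, all inputs are cast to fp16 before the operator runs
  • fp32: for operators on this list, all inputs are cast to fp32 before the operator runs
  • fp32_set_opt_dtype: for operators on this list, autocast does nothing if an output dtype has been specified; otherwise it sets the output dtype to fp32
  • fp32_append_dtype: for operators on this list, an fp32 dtype argument is appended to the end of the argument list, and the call is then dispatched to the overload that accepts it
  • promote: operators on this list run in the widest dtype among their arguments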
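Two of these policies carry a little decision logic of their own. The following toy sketch, written for this article with hypothetical types and requiring C++17, illustrates that logic; it is not the PyTorch implementation, which additionally checks whether the first argument is eligible before touching the dtype. fp32_set_opt_dtype keeps a caller-supplied output dtype and otherwise falls back to Float; promote starts from Half and widens to the widest tensor dtype among the arguments, ignoring non-tensor arguments, in the spirit of promote_type(at::kHalf, args...) in the source.

```cpp
#include <optional>

// Toy dtypes, declared from narrowest to widest so enum order can be compared.
enum class DType { Half = 0, Float = 1, Double = 2 };
struct Tensor { DType dtype; };

// fp32_set_opt_dtype (toy): respect an explicit output dtype, otherwise use Float.
constexpr DType set_opt_dtype(std::optional<DType> requested) {
  return requested.value_or(DType::Float);
}

// promote (toy): widen the running dtype one argument at a time.
constexpr DType wider(DType a, DType b) {
  return static_cast<int>(a) >= static_cast<int>(b) ? a : b;
}
constexpr DType promote_one(DType current, Tensor t) { return wider(current, t.dtype); }
template <class T>
constexpr DType promote_one(DType current, T /*non_tensor*/) { return current; }

template <class... Args>
constexpr DType promote_type(DType start, Args... args) {
  DType result = start;
  ((result = promote_one(result, args)), ...);  // C++17 fold over the comma operator
  return result;
}
```

Starting from Half means that a call whose tensors are all fp16 stays in fp16, while a single fp32 tensor among the arguments pulls the whole operator up to fp32.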



Each policy is implemented by a specialization of a wrapper template in autocast_mode.cpp:

... ...
/*
Templates to provide wrapper functions

I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to extract args and return type.
(see also

This strategy uses an exterior "WrapFunction" that extracts arguments on behalf of
(in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
*/

// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy
template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};

// CastPolicy::fp16
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    return (*F)(cached_cast(at::kHalf, args)...);
  }
};

// CastPolicy::fp32
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    return (*F)(cached_cast(at::kFloat, args)...);
  }
};

// CastPolicy::fp32_set_opt_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    if (firstarg_is_eligible(args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args.  Does not set opt dtype, because setting
      // opt dtype explicitly may interfere with internal implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    auto to_type = promote_type(at::kHalf, args...);
    return (*F)(cached_cast(to_type, args)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
         class Registered, // The signature for which we're registering.  The dispatcher's calling code invokes our
                           // registered functions with arguments matching Registered, so we register
                           // WrapFunction_::call methods with a matching signature to properly field those arguments.
                           // guts::function_traits below extracts return_type and parameter_types from Registered,
                           // which WrapFunction_ templates above use to declare their call methods.
         class Redispatch, // The signature for the function we're redispatching to.  In most cases this is the same
                           // as Registered, but for some ops (for example, ops where we append a dtype) it's useful
                           // to redispatch to a function with a different signature.
         Redispatch* F>    // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<policy,
                             Redispatch,
                             F,
                             typename guts::function_traits<Registered>::return_type,
                             typename guts::function_traits<Registered>::parameter_types>;
};
... ...
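The exterior/interior split described in the comment above can be reproduced outside PyTorch. The sketch below is a toy written for this article: typelist, function_traits, Policy, Wrap_ and Wrap are hypothetical stand-ins for guts::typelist::typelist, guts::function_traits, CastPolicy, WrapFunction_ and WrapFunction, Registered and Redispatch are collapsed into one Sig parameter, and the two "policies" merely transform integer arguments. The exterior template only needs the signature; it derives the return type and the parameter list, so the interior specialization can declare a call method whose signature matches the registered one exactly.

```cpp
#include <type_traits>

template <class... Ts> struct typelist {};

// Minimal stand-in for guts::function_traits: split a function type into
// its return type and a typelist of its parameter types.
template <class FuncType> struct function_traits;
template <class Ret, class... Args>
struct function_traits<Ret(Args...)> {
  using return_type = Ret;
  using parameter_types = typelist<Args...>;
};

enum class Policy { twice, plus_one };

// Interior template: only the partial specializations below have a call method.
template <Policy policy, class Sig, Sig* F, class Ret, class ArgList> struct Wrap_ {};

template <class Sig, Sig* F, class Ret, class... Args>
struct Wrap_<Policy::twice, Sig, F, Ret, typelist<Args...>> {
  static constexpr Ret call(Args... args) { return (*F)((args * 2)...); }
};

template <class Sig, Sig* F, class Ret, class... Args>
struct Wrap_<Policy::plus_one, Sig, F, Ret, typelist<Args...>> {
  static constexpr Ret call(Args... args) { return (*F)((args + 1)...); }
};

// Exterior template: derives Ret and Args... from Sig, so callers only name
// the policy, the signature and the function.
template <Policy policy, class Sig, Sig* F>
struct Wrap final {
  using type = Wrap_<policy, Sig, F,
                     typename function_traits<Sig>::return_type,
                     typename function_traits<Sig>::parameter_types>;
};

constexpr int sum2(int a, int b) { return a + b; }

// The interior call method has exactly the same signature as sum2.
static_assert(std::is_same<decltype(&Wrap<Policy::twice, decltype(sum2), &sum2>::type::call),
                           int (*)(int, int)>::value,
              "call matches the wrapped signature");
```

Because call has the same signature as the wrapped function, a pointer to it can be registered anywhere the original function pointer could, which is what the KERNEL registrations rely on.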



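To understand the template techniques used above, let's build a similar wrapper step by step with a simplified example. Suppose we have the following functions and want a single generic wrapper that can process the arguments of any of them before calling it:

int add_one(int a) {
  int res = a + 1;
  return res;
}

float add_two(float a, float b) {
  float res = a + b + 1;
  return res;
}

long add_three(long a, long b, long c) {
  long res = a + b + c + 1;
  return res;
}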



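A first, naive attempt passes the function pointer as a void* non-type template parameter:

... ...
template<void* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return (*f)(args...);
}

int main() {
  wrapper<&add_one, int>(-10);
  return 0;
}

This does not compile, for two reasons:

  • the function pointer &add_one cannot be passed as a void* non-type template argument (there is no implicit conversion from a function pointer to void*)
  • a void* cannot be dereferenced before it has been cast to a concrete type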

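Most of us have read about the restriction on dereferencing void* in textbooks, but I wonder how many know the reason behind it. A void* points to one of two things: an object or a function. For an object, the restriction is easy to understand: without the concrete type, the memory layout is unknown. For a function, however, the generated assembly only needs the address to emit a call, so why does the compiler need the function type? I could not find an answer online, but after some thought I arrived at my own explanation: under a calling convention, the arguments are set up and passed by the caller, so without the function's type the caller does not know the number and sizes of the parameters and therefore cannot construct and pass them.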


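One workaround is to pass the function pointer as an ordinary runtime argument and cast it back to a concrete function-pointer type inside the wrapper:

... ...
template<typename Ret, typename... Args>
Ret wrapper(void* f, Args... args) {
  // TODO: wrapper func can do something for args here ...
  typedef Ret (*ft)(Args...);
  return ((ft)f)(args...);
}

int main() {
  wrapper<int, int>((void*)&add_one, -10);
  return 0;
}

This compiles, but the function pointer has become a runtime value. Passing the function type as an extra template parameter FT lets f be declared with its concrete type FT*, so it can remain a template argument:
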
... ...
template<typename FT, FT* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return f(args...);
}

int main() {
  wrapper<int(int), &add_one, int>(-10);
  return 0;
}

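With this change, the function pointer is available at compile time. There is still room for improvement, though: the first template parameter FT already encodes the function's parameter and return types, so requiring the user to supply the return type again is redundant. In principle, the return type can be extracted with type traits: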

template<typename FT>
struct func_traits{};

template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
  using func_type = Ret(Args...);
  using ret_type = Ret;
  // you can get other traits here
};

template<typename FT, FT* f, typename... Args>
// By default, C++ assumes that a name accessed through the scope operator of a dependent type is not a type.
// The keyword typename tells the compiler that func_traits<FT>::ret_type is a type.
typename func_traits<FT>::ret_type wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return f(args...);
}

int main() {
  wrapper<int(int), &add_one>(-10);
  return 0;
}
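For comparison, when C++17 is available the standard library and an auto non-type template parameter can do the same job without a hand-written func_traits. This is an alternative technique, not what PyTorch 1.9 or the example above uses: template<auto f> lets the compiler deduce the function-pointer type, and std::invoke_result_t computes the return type. add_one is redeclared constexpr here only so that the result can be checked at compile time.

```cpp
#include <type_traits>

// Same as add_one above, but constexpr so it can be evaluated at compile time.
constexpr int add_one(int a) { return a + 1; }

// `auto` deduces the type of f (here int (*)(int)), so no FT parameter is needed,
// and std::invoke_result_t replaces the hand-written func_traits.
template <auto f, typename... Args>
constexpr std::invoke_result_t<decltype(f), Args...> wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return f(args...);
}
```

The call site shrinks to wrapper<&add_one>(-10), with neither the function type nor the return type spelled out.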


... ...
template<typename T>
T my_abs(T t) {
  if(t < 0){
    return -t;
  else {
    return t;

template<typename FT>
struct func_traits{};

template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
  using func_type = Ret(Args...);
  using ret_type = Ret;
  // you can get other traits here

template<typename FT, FT* f, typename... Args>
typename func_traits<FT>::ret_type wrapper(Args... args) {
  return f(my_abs(args)...); //模板参数包unpack

int main() {
  wrapper<int(int), &add_one>(-10);
  wrapper<float(float, float), &add_two>(-10, -10);
  wrapper<long(long, long, long), &add_three>(-10, -10, -10);
  return 0;

Note that "my_abs(args)..." is parameter pack expansion syntax: it expands to "my_abs(arg0), ..., my_abs(argn)".
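A small standalone check of the expansion rule, written as a toy for this article (scale, sum3 and the two helper templates are hypothetical): the pattern before the ... is repeated once per pack element, and any non-pack operand inside the pattern is copied into every element. The same mechanism turns cached_cast(at::kHalf, args)... in the fp16 wrapper into one cached_cast(at::kHalf, argi) call per argument.

```cpp
template <typename T>
constexpr T my_abs(T t) { return t < 0 ? -t : t; }

// The non-pack operand k is repeated in each element: scale(k, a0), scale(k, a1), ...
template <typename T>
constexpr T scale(int k, T t) { return k * t; }

constexpr int sum3(int a, int b, int c) { return a + b + c; }

// sum3(my_abs(args)...) expands to sum3(my_abs(a0), my_abs(a1), my_abs(a2)).
template <typename... Args>
constexpr int abs_then_sum(Args... args) { return sum3(my_abs(args)...); }

// sum3(scale(k, args)...) expands to sum3(scale(k, a0), scale(k, a1), scale(k, a2)).
template <typename... Args>
constexpr int scale_then_sum(int k, Args... args) { return sum3(scale(k, args)...); }
```

Writing my_abs(args...) instead would pass the whole pack to a single my_abs call, which fails to compile here because my_abs takes exactly one argument.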



