PyTorch AMP autocast源码解析

PyTorch框架AMP(Automatic Mixed Precision)模块的源码早在一年之前便已初探过,然而,最近有人向我咨询某些源码问题时,一些自认为当时已了然于胸的细节,如今已回忆不起来。果然,好记性还是不如烂笔头,于是趁着这次重新梳理,将源码中重要的细节记录下来(针对PyTorch1.9版本源码),同时也希望能够帮助到对这部分源码有疑惑的朋友。

PyTorch AMP主要的功能分为Autocasting和Gradient Scaling,这两部分的算法原理以后有空再介绍。本文主要分析Autocasting功能的技术细节,Gradient Scaling不作探究,主要因为Autocasting功能涉及的部分C++代码比较有趣,而Gradient Scaling为Python实现,相对简单。

1、什么是Autocasting ?

autocast,顾名思义,即自动进行cast操作。在PyTorch中,指无需人为干预的情况下,自动对算子的输入进行cast操作。使用方法可参考PyTorch AMP autocast官方文档

# Creates model and optimizer in default precision
model = Net().cuda()
optimizer = optim.SGD(model.parameters(), ...)

for input, target in data:
    optimizer.zero_grad()

    # Enables autocasting for the forward pass (model + loss)
    with autocast():
        output = model(input)
        loss = loss_fn(output, target)

    # Exits the context manager before backward()
    loss.backward()
    optimizer.step()

autocast功能的api作为python上下文管理器使用,在上下文作用域内的算子,PyTorch会根据内部的策略,选择性对算子的输入进行操作(绝大部分是进行cast,部分算子将目标dtype信息作为参数传给算子),从而达到改变算子实际计算精度的功能。

2、Autocasting内部原理

关于autocast功能的核心代码位于https://github.com/pytorch/pytorch/blob/v1.9.0/aten/src/ATen/autocast_mode.cpp,这部分代码数量虽然不多,但对于PyTorch新手或者C++新手来说,还是存在一定障碍。为了让新手们能够快速了解这部分原理,下面我会从功能实现者的角度来剖析这部分源码。

autocast的要做的事情,简单来说就是:在进入算子计算之前,选择性的对输入进行cast操作。为了做到这点,在PyTorch1.9版本的架构上,可以分解为如下两步:

  • 在PyTorch算子调用栈上某一层插入处理函数
  • 在处理函数中对算子的输入进行必要操作

2.1 利用Autocast dispatchkey注册处理函数

在PyTorch的基础架构中,dispatcher可以说是其最核心的部分,简单来说,dispatcher是一个根据输入的属性决定调用哪个函数的系统。关于dispatcher的详细说明,建议阅读Edward Yang的这篇优秀文章http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/
利用该文章中提到的 Operator Registration API 即可往dispatch table中插入函数指针。PyTorch已经定义了Autocast dispatch key,直接利用TORCH_LIBRARY_IMPL(aten, Autocast, m)即可达到插入处理函数的目的。因此,autocast_mode.cpp中后半部分源码都是在利用TORCH_LIBRARY_IMPL接口为多个算子插入处理函数:

TORCH_LIBRARY_IMPL(aten, Autocast, m) {
 KERNEL(ADD_NS(_convolution), "_convolution.deprecated", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool), fp16)
 KERNEL(ADD_NS(_convolution), "_convolution", Tensor (const Tensor &, const Tensor &, const c10::optional<Tensor>&, IntArrayRef, IntArrayRef, IntArrayRef, bool, IntArrayRef, int64_t, bool, bool, bool, bool), fp16)
 ... ...

2.2 在处理函数中对算子进行autocast

根据Autocast key的优先级,利用*TORCH_LIBRARY_IMPL(aten, Autocast, m)*注册的处理函数会在进入对应算子kernel前被调用,类似python的装饰器一样,处理函数会拿到接下来要运行的函数已经函数的所有参数,因此,处理函数的任务简单来说就是:对算子参数进行进行过滤(根据autocast策略选择性进行类型转换),再往后继续执行
PyTorch的autocast策略如下:

// Policies correspond to op categories that need code-divergent handling.
// Wrapper templates below are specialized based on a policy template parameter.
enum class CastPolicy : uint8_t {
  fp16 = 0, // Cast all inputs to at::kHalf before running the op.
  fp32, // Cast all inputs to at::kFloat before running the op.
  fp32_set_opt_dtype, // Treats functions (like softmax) that
                      //   1. we'd like to run in fp32 and
                      //   2. have a c10::optional<ScalarType> arg that controls the output type.
                      // fp32_set_opt_dtype wrappers' policy is:  if the output type is already set,
                      // don't touch it, otherwise, set it to at::kFloat.
   fp32_append_dtype, // Treats functions (like norm) that
                      //   1. we'd like to run in fp32 and
                      //   2. have some overloads that accept an output type and other overloads that don't.
                      // fp32_append_dtype wrappers wrap the overloads that don't have an output dtype.
                      // The wrapper policy is:  append at::kFloat to the args, and redispatch to the
                      // type-aware overload.
   promote, // Run in the widest dtype among several args.
};

根据上述策略,将需要进行autocast的算子划分了如下名单:

  • fp16 :该名单内的算子,在算子执行前将所有输入cast成fp16
  • fp32:该名单内的算子,在算子执行前将所有输入cast成fp32
  • fp32_set_opt_dtype:该名单内的算子,如果有指定output的dtype,那么autocast什么都不做;否则将output指定为fp32
  • fp32_append_dtype:该名单内的算子,在其参数列表最后增加fp32的dtype参数,然后分发到重载版算子
  • promote:该名单内的算子,将采用参数列表中精度最高的dtype进行运算

观察该名单可以发现,fp32、fp32_set_opt_dtype和fp32_append_dtype这三种类型的算子都是要进行fp32计算,那为何要分成三种策略呢?个人认为,之所以划分成三种策略,是由于不同算子在框架中的实现方式差异导致。在fp32名单中的算子,需要对输入强制cast才能指定计算类型,而fp32_set_opt_dtype和fp32_append_dtype名单内的算子,计算类型大多和输出类型有关,所以增加了参数控制输出类型,即控制计算类型。

名单中的不同策略,对应着在处理函数中对算子的参数做不同逻辑的处理。autocast_mode.cpp前半部分代码正是针对几种autocast策略的处理函数的实现:

... ...
/********************************************************************************************************
Templates to provide wrapper functions
I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to extract args and return type.
(see also https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer)
This strategy uses an exterior "WrapFunction" that extracts arguments on behalf of
(in my case several specializations of) an interior "WrapFunction_".
Interior WrapFunction_ specializations are defined for each CastPolicy.
********************************************************************************************************/

// Base template for WrapFunction_, which is specialized to contain a "call" method each CastPolicy
template<CastPolicy policy, class Redispatch, Redispatch* F, class Ret, class ArgList> struct WrapFunction_ {};

// CastPolicy::fp16
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp16, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    return (*F)(cached_cast(at::kHalf, args)...);
  }
};

// CastPolicy::fp32
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    return (*F)(cached_cast(at::kFloat, args)...);
  }
};

// CastPolicy::fp32_set_opt_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_set_opt_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    if (firstarg_is_eligible(args...)) {
      return (*F)(set_opt_dtype(at::kFloat, args)...);
    } else {
      // If ineligible, calls F with unaltered args.  Does not set opt dtype, because setting
      // opt dtype explicitly may interfere with internal implicit promotion decisions.
      return (*F)(args...);
    }
  }
};

// CastPolicy::fp32_append_dtype
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::fp32_append_dtype, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    at::ScalarType out_type = type_from_firstarg(at::kFloat, args...);
    return (*F)(args..., out_type);
  }
};

// CastPolicy::promote
template<class Redispatch, Redispatch* F, class Ret, class... Args>
struct WrapFunction_<CastPolicy::promote, Redispatch, F, Ret, guts::typelist::typelist<Args...>> {
  static Ret call(Args... args) {
    c10::impl::ExcludeDispatchKeyGuard no_autocast(DispatchKey::Autocast);
    auto to_type = promote_type(at::kHalf, args...);
    return (*F)(cached_cast(to_type, args)...);
  }
};

// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating core/boxing/impl/WrapFunctionIntoFunctor.h)
template<CastPolicy policy,
         class Registered, // The signature for which we're registering.  The dispatcher's calling code invokes our
                           // registered functions with arguments matching Registered, so we register
                           // WrapFunction_::call methods with a matching signature to properly field those arguments.
                           // guts::function_traits below extracts return_type and parameter_types from Registered,
                           // which WrapFunction_ templates above use to declare their call methods.
         class Redispatch, // The signature for the function we're redispatching to.  In most cases this is the same
                           // as Registered, but for some ops (for example, ops where we append a dtype) it's useful
                           // to redispatch to a function with a different signature.
         Redispatch* F>    // The actual function we're redispatching to.
struct WrapFunction final {
  using type = WrapFunction_<policy,
                             Redispatch,
                             F,
                             typename guts::function_traits<Registered>::return_type,
                             typename guts::function_traits<Registered>::parameter_types>;
};
... ...

这段代码虽然不长,但值得学习和借鉴的地方却很多。下面尝试从代码贡献者的角度,构造一个同样场景的需求,简化并分析这段代码。

假设有这样一个需求:现有如下三个函数,需要编写一个包装器wrapper,在不改变原有函数实现的基础上,对原有函数的输入进行非负判断和绝对值操作

int add_one(int a) {
  int res = a + 1;
  return res;
}

float add_two(float a, float b) {
  float res = a + b + 1;
  return res;
}

long add_three(long a, long b, long c) {
  long res = a + b + c + 1;
  return res;
}

一个常见的思路是使用模板函数设计wrapper,然后在使用的时候将待包装的函数地址和参数传给wrapper。

在实现wrapper的过程中,我相信会有人会写出类似这样的代码:

... ...
template<void* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return (*f)(args...);
}

int main() {
  wrapper<&add_one, int>(-10);
  return 0;
}

看上去十分简洁,其实暗藏两个小Tips:

  • void* 不能作为non-type template的模版参数
  • void* 未强转到具体类型之前不能解引用

关于void解引用的限制,相信大家都在书中看到过,至于背后的原因,不知道又有几人知晓。void指向的内容无非两种,object或function,对于object,如果不知道具体类型,便无法知其内存布局,因此该限制不难理解,不过对于function,生成的汇编代码只要知道其地址就可以直接call,为什么还要知道其函数类型呢?关于这个疑问我在网上没找到答案,不过思考再三后我有一个自己的理解:在Calling Convention的规定中,函数参数的传递是由caller完成的,如果不知道函数的类型,也就不知道其参数的大小和个数,因此无法完成参数的构造和传递。

既然不能使用void*接收不同类型的函数指针,那就必须提供完整的函数类型,因此wrapper可以这么写:

... ...
template<typename Ret, typename... Args>
Ret wrapper(void* f, Args... args) {
  // TODO: wrapper func can do something for args here ...
  typedef Ret (*ft)(Args...);
  return ((ft)f)(args...);
}

int main() {
  wrapper<int, int>((void*)&add_one, -10);
  return 0;
}

这种写法虽然可以达到目的,但不够优美,函数指针理论上在编译期就可以得到,无需在运行时传入,针对这一点,代码可以进一步优化:

... ...
template<typename FT, FT* f, typename Ret, typename... Args>
Ret wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return f(args...);
}

int main() {
  wrapper<int(int), &add_one, int>(-10);
  return 0;
}

经过优化,这次函数指针在编译期就可以获得。不过,上述代码仍然存在可优化的空间:第一个模板参数FT已经包含了函数的入参和返回值类型,需要用户再次提供返回值类型稍显累赘,理论上完全可以通过Type Traits提取返回值的类型:

template<typename FT>
struct func_traits{};

//通过类模板部分特化,利用编译器推导出函数的返回值类型
template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
  using func_type = Ret(Args...);
  using ret_type = Ret;
  // you can get other traits here
};


template<typename FT, FT* f, typename... Args>
//默认情况下,C++ 语言假定通过作用域运算符访问的名字不是类型。
//因此,如果我们希望使用一个模板类型参数的类型成员,就必须显式告诉编译器该名字是一个类型。
//我们通过使用关键字 typename 来实现这一点。
typename func_traits<FT>::ret_type wrapper(Args... args) {
  // TODO: wrapper func can do something for args here ...
  return f(args...);
}

int main() {
  wrapper<int(int), &add_one>(-10);
  return 0;
}

func_traits的实现利用类模板的偏特化推导出函数的返回值类型,避免了用户再次传入Ret类型的限制。至此,只剩下对函数的入参做非负判断和绝对值操作了:

... ...
template<typename T>
T my_abs(T t) {
  if(t < 0){
    return -t;
  }
  else {
    return t;
  }
}

template<typename FT>
struct func_traits{};

template<typename Ret, typename... Args>
struct func_traits<Ret(Args...)>{
  using func_type = Ret(Args...);
  using ret_type = Ret;
  // you can get other traits here
};


template<typename FT, FT* f, typename... Args>
typename func_traits<FT>::ret_type wrapper(Args... args) {
  return f(my_abs(args)...); //模板参数包unpack
}

int main() {
  wrapper<int(int), &add_one>(-10);
  wrapper<float(float, float), &add_two>(-10, -10);
  wrapper<long(long, long, long), &add_three>(-10, -10, -10);
  return 0;
}

值得注意的是,"my_abs(args)…"是模板参数unpack语法,展开应该是“my_abs(arg0), … … , my_abs(argn)”的形式。
到此为止,我们已经完成了wrapper包装器的所有功能。麻雀虽小,五脏俱全,autocast_mode.cpp中wrapper代码的所有核心功能都已包含,相信此时再看源码心中应该了无疑惑了。

  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
PyTorch中的autocast是一个上下文管理器,可以自动将特定操作转换为半精度(float16)运算,以提高模型的训练和推理效率。它可以减少内存使用和加速计算,尤其是在具有深度神经网络的大型模型中。当一个操作被autocast包含的上下文管理器包裹时,PyTorch将自动将其转换为float16运算,以便于GPU进行计算。如果操作的输出需要在其他操作中使用,那么PyTorch会自动将其转换回float32。 使用autocast需要安装PyTorch 1.6及以上版本,并且需要在支持半精度的GPU上运行。在代码中使用autocast时,只需要将需要进行半精度运算的操作放在autocast的上下文管理器中即可。 示例代码: ``` from torch.cuda.amp import autocast with autocast(): output = model(input) loss = criterion(output, target) # 反向传播 loss.backward() ``` 在这个示例中,model和criterion是PyTorch中的模型和损失函数,input和target是训练数据和标签。在with autocast()上下文管理器中,PyTorch会自动将output和loss转换为float16运算。反向传播时,PyTorch会自动将梯度转换回float32,并进行优化。 需要注意的是,有些操作不适合使用半精度运算,如含有大量整数的操作,这些操作应该被排除在autocast的上下文管理器之外。可以使用torch.cuda.amp.autocast(enabled=False)来禁用autocast,或者在上下文管理器中使用torch.cuda.amp.custom_fwd和torch.cuda.amp.custom_bwd来自定义特定操作的半精度实现。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值