torch nn Linear 大揭秘——从python到CUDA

本篇博客从顶到底展示了 torch 软件栈的调用路径,以pytorch2.0为例

Python 部分

Linear — PyTorch 2.0 documentation 在深度学习中经常以"构建全连接层"的参与者出现在深度学习模型中,这里的代码是开源的:

class Linear(Module):
    # ...
    def forward(self, input: Tensor) -> Tensor:
        return F.linear(input, self.weight, self.bias)

torch.nn.functional.linear — PyTorch 2.0 documentation 涉及到了 torch.nn.functional.linear 的接口入口。

但具体该接口是如何实现的,如何调用C++以及后面的CUDA库函数的,pytorch 这边刻意地隐藏了,要想知道还要花不少功夫。

Where is the nn.Linear cuda implementation - C++ - PyTorch Forums 幸好这位 ptrblck 大神出面帮我指路,找到了 Aten 相关的 linear 实现

C++部分

现在,我们把C++部分的拿出来:

Tensor linear(const Tensor& input, const Tensor& weight, const c10::optional<Tensor>& bias_opt) {
  // See [Note: hacky wrapper removal for optional tensor]
  auto bias = bias_opt.has_value()
    ? c10::MaybeOwned<Tensor>::borrowed(*bias_opt)
    : c10::MaybeOwned<Tensor>::owned(c10::in_place);

  if (input.is_mkldnn()) {
    return at::mkldnn_linear(input, weight, *bias);
  }
#if defined(C10_MOBILE)
  if (xnnpack::use_linear(input, weight, *bias)) {
    return xnnpack::linear(input, weight, *bias);
  }
#endif
  if (input.dim() == 2 && bias->defined()) {
    // Fused op is marginally faster.
    return at::addmm(*bias, input, weight.t());
  }
  auto output = at::matmul(input, weight.t());
  if (bias->defined()) {
    output.add_(*bias);
  }
  return output;
}

很显然,linear 的 C++ 部分会调用到 at::addmm 或者 at::matmul,这里可能知道,对于二维的 Tensor,会调用 at::addmm,直接计算矩阵乘加,因为 Fused op 更快一些。而其他情况要调用 matmul,鉴于更多的深度学习模型的张量在二维以上,我们来重点看一下 at::matmul 又是如何被实现的,最后又怎样落到了 CUDA 层面。

我们继续循迹,发现在 pytorch/include/ATen/ops/matmul.h 层面

// aten::matmul(Tensor self, Tensor other) -> Tensor
inline at::Tensor matmul(const at::Tensor & self, const at::Tensor & other) {
    return at::_ops::matmul::call(self, other);
}

而真正实现 matmul 的地方,在这里 pytorch/aten/src/ATen/native/LinearAlgebra.cpp at main · pytorch/pytorch (github.com)

Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
  auto maybe_outnames = namedinference::compute_matmul_outnames(tensor1, tensor2);
  at::Tensor result, unused;
  result = at::native::_matmul_impl(unused, tensor1, tensor2);
  namedinference::propagate_names_if_nonempty(result, maybe_outnames);
  return result;
}

这里 torch 又套了一层 _matmul_impl,好在代码就在这个函数的上面,寻起来不费劲。pytorch/aten/src/ATen/native/LinearAlgebra.cpp at main · pytorch/pytorch (github.com) 粗略扫了一眼这里的逻辑,前半部分在对特殊情况做讨论,把维度小的尺寸小的张量踢出去调用其他 API,提升性能。后半部分如下:

    // ....
 } else {
    // dim_tensor1 >= 3 || dim_tensor2 >= 3
    // We track m1 vs m2 separately even though they must match for nicer error messages
    const int64_t n = dim_tensor1 > 1 ? tensor1.sizes().cend()[-2] : 1LL;
    const int64_t m1 = tensor1.sizes().back();
    auto batch_tensor1 = tensor1.sizes().slice(0, std::max<int64_t>(dim_tensor1 - 2, 0LL));
    const int64_t m2 = dim_tensor2 > 1 ? tensor2.sizes().cend()[-2] : tensor2.sizes().front();
    const int64_t p = dim_tensor2 > 1 ? tensor2.sizes().back() : 1LL;
    const IntArrayRef batch_tensor2(tensor2.sizes().data(),
                                    std::max<int64_t>(dim_tensor2 - 2, 0LL));

    // Same optimization for the gradients as that in should_fold
    // If we're going to broadcast we force it to go through the should_fold branch
    if (dim_tensor1 == 3 && dim_tensor2 == 3 && batch_tensor1[0] != batch_tensor2[0]) {
      if (batch_tensor1[0] == 1 && (tensor1.requires_grad() || isTensorSubclassLike(tensor1))) {
        return _matmul_impl(out, tensor1.squeeze(0), tensor2);
      }
      if (batch_tensor2[0] == 1 && (tensor2.requires_grad() || isTensorSubclassLike(tensor2))) {
        return _matmul_impl(out, tensor1, tensor2.squeeze(0));
      }
    }

    auto output_shape = infer_size_dimvector(batch_tensor1, batch_tensor2);
    const int64_t expand_batch_product = c10::multiply_integers(output_shape);

    // flatten expanded batches
    const auto tensor1_expand_size = [&output_shape, n, m1]{ DimVector ret(output_shape);
                                                             ret.append({n, m1});
                                                             return ret; }();
    const auto tensor1_expanded = tensor1.expand(tensor1_expand_size)
                                         .reshape({expand_batch_product, n, m1});
    // We need to treat the dim_tensor2 == 1 case separately as broadcasting would not convert
    // a vector of shape (n,) into a batch of matrices of shape (*, n, 1)
    auto vector_rhs = dim_tensor2 == 1;
    const auto tensor2_expand_size = [&output_shape, m2, p, vector_rhs]{
      DimVector ret(output_shape);
      if (vector_rhs) {
        ret.push_back(m2);
      } else {
        ret.append({m2, p});
      }
      return ret;
    }();
    auto tensor2_expanded = tensor2.expand(tensor2_expand_size);
    if (vector_rhs) {
      tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2}).unsqueeze(2);
    } else {
      tensor2_expanded = tensor2_expanded.reshape({expand_batch_product, m2, p});
    }

    if (dim_tensor1 > 1) {
      output_shape.push_back(n);
    }
    if (dim_tensor2 > 1) {
      output_shape.push_back(p);
    }

    if (!has_out) {
      if (vector_rhs) {
        return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded).squeeze(-1), output_shape);
      } else {
        return at::_unsafe_view(tensor1_expanded.bmm(tensor2_expanded), output_shape);
      }
    } else {
      at::native::resize_output(out, output_shape);
      auto reshaped_out = out.reshape({expand_batch_product, n, p});
      at::bmm_out(reshaped_out, tensor1_expanded, tensor2_expanded);
      if (vector_rhs) {
        reshaped_out = reshaped_out.squeeze(-1);
      }
      if (!reshaped_out.is_alias_of(out)) {
        out.copy_(reshaped_out.view_as(out));
      }
      return out;
    }
  }

主要做了如下几件事:

  1. 变量初始化: 从张量中获取维度信息,并初始化一些变量(如n, m1, m2, p等),这些变量通常代表矩阵的维度。同时,也处理了batch维度。

  2. 优化处理 如果两个张量都是3维,并且它们的batch大小不匹配,则尝试进行优化。如果其中一个batch维度为1,并且对应的张量需要梯度计算或是某种特定类型的张量子类,则通过squeeze操作移除这个维度,并执行矩阵乘法。

  3. 形状推断和扩展: 推断输出张量的形状(output_shape);根据推断的形状,对输入张量进行 expand 和 reshape,以便它们可以进行批量矩阵乘法。这里考虑了多种情况,包括当tensor2是一个向量时的特殊处理。

  4. 执行矩阵乘法: 使用扩展和重塑后的张量执行批量矩阵乘法(bmm)。如果tensor2原本是一个向量,那么在乘法后会移除额外添加的维度。根据是否有输出张量(out)的指定,决定是直接返回结果还是将结果复制到输出张量中

也就是说,如果张量维度在二维以上,那么很可能我们最后回到了 at::bmm 和 at::bmm_out 这几个接口上,我们先看看它们对应的 pytorch 的功能是什么:

看来就是矩阵乘法!

但问题来了,bmm是怎么调用的?

How to find c++ source code of torch.bmm of pytorch - Stack Overflow 根据这个回答,我顺藤摸瓜找到了 pytorch2.0版本的情况:

- func: bmm(Tensor self, Tensor mat2) -> Tensor
  structured_delegate: bmm.out
  variants: function, method
  dispatch:
    SparseCPU: bmm_sparse_cpu
    SparseCUDA: bmm_sparse_cuda
    NestedTensorCPU: bmm_nested
    NestedTensorCUDA: bmm_nested_cuda
  tags: core

- func: bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)
  structured: True
  variants: function
  dispatch:
    CPU: bmm_out_cpu
    CUDA: bmm_out_cuda
    MPS: bmm_out_mps
    SparseCPU: bmm_out_sparse_cpu
    SparseCUDA: bmm_out_sparse_cuda
    SparseCsrCUDA: bmm_out_sparse_csr_cuda

也就是说,bmm_out 调用的CUDA名称是 bmm_out_cuda:

TORCH_IMPL_FUNC(bmm_out_cuda)(const Tensor& batch1, const Tensor& batch2, const Tensor &result) {
  Scalar beta(0.0);
  Scalar alpha(1.0);
  {
    NoNamesGuard guard;
    baddbmm_out_cuda_impl(result, result, batch1, batch2, beta, alpha);
  }
}

而它又会去使用 baddbmm_out_cuda_impl(真是魔鬼般复杂),baddbmm 是一个批量版本的addbmm 操作,b 表示batch。不过好在,这应该是最后一层:这里给个链接方便读者去找:pytorch/aten/src/ATen/native/cuda/Blas.cpp at v2.0.0 · pytorch/pytorch (github.com)

const Tensor& baddbmm_out_cuda_impl(const Tensor& result, const Tensor& self, const Tensor& batch1, const Tensor& batch2, const Scalar& beta, const Scalar& alpha) {
  IntArrayRef batch1_sizes = batch1.sizes();

  // handle pathological cases that blas may not like
  if (result.numel() == 0) {
    return result;
  } else if (batch1_sizes[2] == 0) {
    if (beta.to<c10::complex<double>>() == 0.0) {
      return result.zero_();
    } else {
      return result.mul_(beta);
    }
  }

  bool transpose_result = false;
  c10::MaybeOwned<Tensor> result_;
  IntArrayRef result_strides = result.strides();
  IntArrayRef result_sizes = result.sizes();

  if ((result_strides[1] == 1) &&
      ((result_sizes[2] == 1) || (result_strides[2] >= std::max<int64_t>(1, result_sizes[1])))) {
    result_ = resolve_conj_if_indicated(result, true);
  } else if ((result_strides[2] == 1) &&
    (result_sizes[1] == 1 || (result_strides[1] >= std::max<int64_t>(1, result_sizes[2])))) {
    transpose_result = true;
    result_ = resolve_conj_if_indicated(result, true);
  } else {
    result_ = c10::MaybeOwned<Tensor>::owned(result.transpose(1, 2).clone(at::MemoryFormat::Contiguous).transpose(1, 2));
  }

  int leading_dim = transpose_result ? 1 : 2;

  int64_t m = result_sizes[transpose_result ? 2 : 1];
  int64_t n = result_sizes[leading_dim];
  int64_t k = (transpose_result ? batch2 : batch1).sizes()[leading_dim];

  int64_t lda, ldb, ldc;
  bool transpose_batch1, transpose_batch2;
  auto batch1_ = prepare_batch_matrix_for_cublas(transpose_result ? batch2 : batch1, transpose_batch1, lda, transpose_result, m, k);
  auto batch2_ = prepare_batch_matrix_for_cublas(transpose_result ? batch1 : batch2, transpose_batch2, ldb, transpose_result, k, n);

  ldc = result_->strides()[leading_dim];
  int64_t num_batches = result_->sizes()[0];

  TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!result_->is_conj());

  AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, self.scalar_type(), "baddbmm_cuda", [&] {
    using opmath_t = at::opmath_type<scalar_t>;
    opmath_t alpha_val = alpha.to<opmath_t>();
    opmath_t beta_val = beta.to<opmath_t>();
    scalar_t* batch1_ptr = batch1_->data_ptr<scalar_t>();
    scalar_t* batch2_ptr = batch2_->data_ptr<scalar_t>();
    scalar_t* result_ptr = result_->data_ptr<scalar_t>();
    at::cuda::blas::bgemm<scalar_t>(
      transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',
      transpose_batch2 ? batch2_->is_conj() ? 'c' : 't' : 'n',
      m, n, k,
      alpha_val,
      batch1_ptr, lda, batch1_->strides()[0],
      batch2_ptr, ldb, batch2_->strides()[0],
      beta_val,
      result_ptr, ldc, result_->strides()[0],
      num_batches
    );
  });
  if (!result.is_same(*result_)) {
    result.copy_(*result_);
  }
  return result;
}

 稍微分析一下这段代码:(By yiyan)

  1. 确定结果Tensor是否需要转置,以及是否需要复制和调整内存布局以保证连续的内存访问。
  2. 准备用于cuBLAS 的batch矩阵,这可能包括确定矩阵的维度、步长以及是否需要转置。
  3. 调用cuBLAS的bgemm函数执行批量矩阵乘法。这个函数可以在GPU上高效地执行批量矩阵乘法。
  4. 如果结果Tensor和内部使用的Tensor不是同一个,则将结果复制回结果Tensor。
  5. 返回结果Tensor。

至此,我们终于到达了 CUDA 的海岸,而前方的 cuBLAS 也是山穷水恶,极为复杂。

CUDA 部分

循着上一章节的踪迹,我们来到了 Blas.cpp 文件找到了 at::cuda::blas::bgemm<scalar_t> 这个函数,通过搜索,发现其出现在 pytorch/aten/src/ATen/cuda/CUDABlas.cpp at v2.0.0 · pytorch/pytorch (github.com) ,因为 cuda blas 对每个类型都分开实现了一个函数,因此针对不同类型,pytorch 这里也根据数据类型分了好几个函数pytorch/aten/src/ATen/cuda/CUDABlas.cpp at v2.0.0 · pytorch/pytorch (github.com)

以深度学习模型中常用的 FP16 类型为例:(仅截取了 CUDA 部分,ROCm 的就略去了)

template <>
void bgemm<at::Half>(CUDABLAS_BGEMM_ARGTYPES(at::Half)) {
  // See Note [Writing Nondeterministic Operations]
  globalContext().alertCuBLASConfigNotDeterministic();
  cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle();
  cublasOperation_t opa = _cublasOpFromChar(transa);
  cublasOperation_t opb = _cublasOpFromChar(transb);
  _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc);
  BGEMM_CHECK_ARGVALUES(at::Half);
  float falpha = alpha;
  float fbeta = beta;
  cudaDeviceProp* prop = at::cuda::getCurrentDeviceProperties();
  if (prop->major >= 5){
    TORCH_CUDABLAS_CHECK(cublasGemmStridedBatchedExFix(
      handle, opa, opb, m, n, k,
      (void*)(&falpha), a, CUDA_R_16F, lda, stridea,
      b, CUDA_R_16F, ldb, strideb, (void*)(&fbeta),
      c, CUDA_R_16F, ldc, stridec,
      num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP));
  } else {
    for (const auto i : c10::irange(num_batches)) {
      at::cuda::blas::gemm<at::Half>(
        transa, transb,
        m, n, k,
        alpha, (a + i * stridea), lda,
        (b + i * strideb), ldb, beta,
        (c + i * stridec), ldc);
    }
  }
}

逐句解释代码:(Written with GPT4)

  1. cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); 获取当前需要执行任何cuBLAS例程的cuBLAS句柄。

  2. cublasOperation_t opa = _cublasOpFromChar(transa);cublasOperation_t opb = _cublasOpFromChar(transb);transatransb 的字符(之前在 C++ 部分中 transpose_batch1 ? batch1_->is_conj() ? 'c' : 't' : 'n',)转换为cublasOperation_t类型,该类型定义了矩阵是否进行了转置操作,其中 c 表示共轭转置,n 表示不转置,t 表示转置。

  3. _cublasAdjustLdLevel3(transa, transb, m, n, k, &lda, &ldb, &ldc); 调整 BLAS Level 3例程的行列 stride:(lda, ldb, ldc)。这是为了正确读取内存中的数据。

  4. BGEMM_CHECK_ARGVALUES(at::Half); 是用于检查批量GEMM操作的参数的正确性的宏。

  5. 随后获取当前 CUDA 设备的软件适配情况和属性,并调用cublasGemmStridedBatchedExFix(...)这个cuBLasGemmStridedBatchedEx的变体允许混合精度计算,并且可以使用张量核心来更快地进行FP16计算。其中,gemm操作使用传递的参数:transatransb 进行转置操作,m, n, k 为矩阵尺寸,alphabeta 为标量值,以及指向矩阵 a, b, 和 c 的指针及其各自的领先维度和步长。

关于其调用的 cublas gemm 接口:1. Introduction — cuBLAS 12.3 documentation

  • 19
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值