【pytorch】op的c++和cuda编写(一)
需求:当你需要频繁的使用该自定义操作,或者调用很昂贵,或者需该操作要用到一些c/c++库。
c++扩展允许用户创建源外定义的PyTorch运算符,即与PyTorch后端分离的运算符。
c++扩展有两种形式:
- 使用setuptools提前构建
- 使用torch.utils.cpp_extension.load()即使构建
setuptools构建
- 写
setup.py
来让setuptools编译c++代码
from setuptools import setup, Extension
from torch.utils import cpp_extension
setup(name='lltm_cpp',
ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],# CppExtension是setuptools的便利包装.Extension传递正确的include路径并将扩展语言设置为c++
cmdclass={
'build_ext': cpp_extension.BuildExtension})# BuildExtension执行许多必需的配置步骤并检查,并且在混合c++ / CUDA扩展的情况下还管理混合编译
初步:编写c++操作
#include <torch/extension.h>
包括编写c++扩展所需的所有必需的pytorch bit
- ATen库:用于张量计算的主要API
- pybind11:为c++代码创建的python绑定方式
- headers:管理AT