1、编写:c++/cuda拓展源文件
pybind11_demo/
├── setup.py
├── example.cpp
└── test.py
example.cpp
#include <torch/extension.h>
#include <vector>
// Forward declaration of the function
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b);
// The actual implementation
torch::Tensor custom_add(torch::Tensor a, torch::Tensor b) {
// Simple element-wise addition
return a + b;
}
// Pybind11 module definition
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("custom_add", &custom_add, "A function that adds two tensors");
}
PyTorch中的PYBIND11_MODULE
PYBIND11_MODULE
是Pybind11库中的一个宏,它用于定义一个Python模块,并将C++类、函数或其他对象绑定到该模块。这使得Python可以直接调用C++编写的函数和类,极大地提高了Python的性能,尤其是当计算密集型任务需要底层C++实现时。
2、编译:setuptools指导c++/cuda拓展的编译
setup.py
from setuptools import setup, Extension
from torch.utils.cpp_extension import CppExtension, BuildExtension,CUDAExtension
setup(
name='python_demo', # python包的名称
ext_modules=[
CppExtension(
name='demo', # 扩展模块名称,后面import使用
sources=['example.cpp'],
extra_compile_args={'CXX': ['-w', '-std=c++14']}
)
],
cmdclass={
'build_ext': BuildExtension
}
)
# python setup.py install
# or for development:
# python setup.py develop
指定构建命令
cmdclass={
'build_ext': BuildExtension
}
cmdclass是一个字典,用于指定自定义的构建命令。
'build_ext'是setuptools中的一个标准构建命令,用于构建扩展模块。
BuildExtension是PyTorch提供的BuildExtension类,它扩展了setuptools的build_ext命令,以支持C++和CUDA扩展的编译。
3、python调用编译完成的库
test.py
import torch
import demo # The name you specified in setup.py
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
result = demo.custom_add(a, b)
print(result) # Should output tensor([5., 7., 9.])
# python test.py
参考
https://zhuanlan.zhihu.com/p/459955492
深入解析PyTorch中的PYBIND11_MODULE:功能与实现_pytorch pybind11-CSDN博客