首先编写 our_add.cpp, 引入torch.h 定义加法
#include <torch/torch.h>
#include <vector>
#include <iostream>
torch::Tensor ADD(const torch::Tensor& a, const torch::Tensor& b){
return a+b;
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m){
m.def("add", &ADD, "OUR_ADD add");
}
编写setup.py
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension
setup(
name = 'our_add',
version = '0.0.1',
ext_modules = [CppExtension('our_add',sources=['add.cpp'])],
cmdclass = {'build_ext': BuildExtension}
)
然后 cmd
python setup.py install
然后会进行编译及安装,完成后,就可以使用了.
import torch
import our_add
print(our_add.add)
x = torch.tensor(10)
y = torch.tensor(20)
z = our_add.add(x, y)
print(z)
-> <built-in method add of PyCapsule object at 0x7fa971e180f0>
-> tensor(30)
编译完成后,会在python环境中进行安装,如果需要移出来用,可以将包our_add-0.0.1-py3.7-linux-x86_64.egg 中的 .so与.py一起复制出来.
** 另 在调用时,先调用 torch ,再调用自己的包