pytorch关于c++的所有文档集合
Welcome to PyTorch Tutorials — PyTorch Tutorials 2.0.0+cu117 documentation
1.前置条件
使用编辑器clion,安装好cudatoolkit,cudnn,pytorch环境,编译工具gcc等等。
记得要设置好cudatoolkit的环境变量和动态链接库,这样到时候才能找到cudatoolkit和cudnn
安装教程可看
https://mp.csdn.net/mp_blog/creation/editor/new/129111146
注意我们如果要使用pytorch 的c语言版,是不需要安装额外的libpytorch的,因为pytorch下载的时候就自动整合了这些。
官方教程
CUDA projects | CLion Documentation
Installing C++ Distributions of PyTorch — PyTorch master documentation
2.通过clion创建cuda可执行项目
参照Installing C++ Distributions of PyTorch — PyTorch master documentation
这两个文件
以及 CMakeLists.txt我们是不需要的,我们使用setup.py代替 CMakeLists.txt
setup.py
参考官方文档
2. Writing the Setup Script — Python 3.6.15 documentation
以及pytorch的 setup.py教程,写的很详细
Custom C++ and CUDA Extensions — PyTorch Tutorials 2.0.0+cu117 documentation
文件项目结构
setup.py的安装代码模板
#python3 setup.py install
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
from distutils.sysconfig import get_config_vars
(opt,) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(
flag for flag in opt.split() if flag != '-Wstrict-prototypes'
)
setup(
name='sptr',
ext_modules=[
CUDAExtension('sptr_cuda', [
'src/sptr/pointops_api.cpp',
'src/sptr/attention/attention_cuda.cpp',
'src/sptr/attention/attention_cuda_kernel.cu',
'src/sptr/precompute/precompute.cpp',
'src/sptr/precompute/precompute_cuda_kernel.cu',
'src/sptr/rpe/relative_pos_encoding_cuda.cpp',
'src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu',
],
extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2', '-g', '-G']}
)
],
cmdclass={'build_ext': BuildExtension}
)
- setup的是一个包,要将什么包给安装上来,是我们要生成的动态链接库的名字
- name='sptr' 是包名,执行python3 setup.py install会安装一个叫sptr的包
- ext_modules表明我要输出的模块,模块才是真正能被python代码调用的,而不是包!,比如我写
- import sptr是找不到模块的,因为他根本就不是模块!,调用import sptr_cuda才会有效。
- CUDAExtension就是拓展模块,比如我有模块sptr_cuda,与他绑定的有哪些cpp文件我写过来,配合pointops_api.cpp(也就是第一行),可以将指定的cpp接口暴露给sptr_cuda模块,使得python代码可以调用。
- extra_compile_args 就是传给 gcc 的额外的编译参数,比方说你可以传一个 -std=c++11
这里c语言的编译器用的是cxx应该也就是gcc不知道为啥要叫做cxx,nvcc就是cu代码的编译器,它也可以编译c++语言。
'nvcc': ['-O2', '-g', '-G'] -O2参数含义O2该优化选项会牺牲部分编译速度,除了执行-O1所执行的所有优化之外,还会采用几乎所有的目标配置支持的优化算法,用以提高目标代码的运行速度。
-g,-G
NVCC, the NVIDIA CUDA compiler driver, provides a mechanism for generating the debugging information necessary for CUDA-GDB to work properly. The
-g -G
option pair must be passed to NVCC when an application is compiled for ease of debugging with CUDA-GDB; for example,也就是生成调试信息,只有nvcc 添加上这两个选项,后面才能链接生成可以被cuda-gdb调试的可执行文件
gcc -g只是编译器,在编译的时候,产生调试信息,通俗来讲是后面生成的可执行文件能够被gdb调试,如果不加-g的话 gdb是无法调试的。
GCC中-O1 -O2 -O3 优化的原理是什么? - 知乎
- cmdclass将BuildExtension类给传入了,
-
torch.utils.cpp_extension.BuildExtension(dist,** kw )
简单来说就是提供参数的,我们直接写就好了
自定义
setuptools
构建扩展。
setuptools.build_ext
子类负责传递所需的最小编译器参数(例如-std=c++11
)以及混合的C ++/CUDA
编译(以及一般对CUDA
文件的支持)。当使用
BuildExtension
时,它将提供一个用于extra_compile_args
(不是普通列表)的词典,通过语言(cxx
或cuda
)映射到参数列表提供给编译器。这样可以在混合编译期间为C ++
和CUDA
编译器提供不同的参数。
(opt,) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(
flag for flag in opt.split() if flag != '-Wstrict-prototypes'
) #设置环境变量opt
目的:创建环境变量opt,里面是执行setup.py传入的默认参数
Wstrict-prototypes:确定是否为未指定参数类型声明或定义的函数发出警告
原先的opt为字符串'-DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes'
结果os.environ['OPT'] 为 '-DNDEBUG -g -fwrapv -O3 -Wall',将-Wstrict-prototypes去除了,其他和get_config_vars('OPT')一样,就是不发出这一种警告了。
执行安装
python3 setup.py install
可以看到我们安装好的sptr在和其他包相同的位置
包是叫sptr-0.0.0-py3.7-linux-x86_64.egg的文件夹
打开后就可以看到我们导出的模块了,我们import导入的就是sptr_cuda.py,然后他又指向动态链接库sptr_cuda.cpython-37m-x86_64-linux-gnu.so,他是我们编译好的动态链接库(就是在运行时去动态的找头文件对应的实现的编译内容),pycache文件就是sptr_cuda.py的对应pyc文件。
EGG-INFO文件夹下存储了一些包的相关信息,其中比如source文件夹就记录了源代码的名称
README.md
setup.py
sptr.egg-info/PKG-INFO
sptr.egg-info/SOURCES.txt
sptr.egg-info/dependency_links.txt
sptr.egg-info/top_level.txt
src/sptr/pointops_api.cpp
src/sptr/attention/attention_cuda.cpp
src/sptr/attention/attention_cuda_kernel.cu
src/sptr/precompute/precompute.cpp
src/sptr/precompute/precompute_cuda_kernel.cu
src/sptr/rpe/relative_pos_encoding_cuda.cpp
src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu
test/test_attention_op_step1.py
test/test_attention_op_step2.py
test/test_precompute_all.py
test/test_relative_pos_encoding_op_step1.py
test/test_relative_pos_encoding_op_step1_all.py
test/test_relative_pos_encoding_op_step2.py
可以据此定位到项目的源代码的位置(可能之后的调试代码的定位也是基于这个原理)
头文件在include文件夹下,so文件在ld_library_path下,然后暴露接口(使用PYBIND11_MODULE),最终导出模块,此时python就可以调用模块的接口了,所以so文件也就是封装好的c语言函数或者类。python调用c++接口的步骤如下:python导入模块,这个模块在site-packages里被找到,比如叫sptr_cuda.py,sptr_cuda.py里代理了很多c++的函数,这些实现都在sptr_cuda.cpython-37m-x86_64-linux-gnu.so中,当python调用函数,就在这里进行寻找实现,so文件完成计算后返回给接口,python程序就得到返回值了。
pointops_api.cpp
可以将指定的cpp接口暴露给sptr_cuda模块,使得python代码可以调用。
pybind11 具体用法
参考
跟我一起学习pybind11 之一 - 腾讯云开发者社区-腾讯云
绑定简单函数
让我们以一个极度简单的函数来开始创建python绑定,函数完成两数相加并返回结果
int add(int i, int j)
{
return i + j;
}
为简单起见,我们将函数和绑定代码都放在example.cpp
这个文件中
#include <pybind11/pybind11.h>
namespace py = pybind11;
int add(int i, int j)
{
return i + j;
}
PYBIND11_MODULE(example, m)
{
m.doc() = "pybind11 example plugin"; // 可选的模块说明
m.def("add", &add, "A function which adds two numbers");
}
PYBIND11_MODULE()
宏函数将会创建一个函数,在由Python发起import
语句时该函数将会被调用(也就是生成模块)。模块名字“example”,由宏的第一个参数指定(千万不能出现引号),比如下面代码就传入sptr。第二个参数"m",定义了一个py::module
的变量,实际也就是我们调用的moudle,传入python的模块。
m.doc:定义该模块的模块文档
m.def:定义该模块的映射参数,函数py::module::def()
生成绑定代码,将add()
函数暴露给Python。
第一个参数"add",表示我以后要在python中通过 模块名.add来调用函数
第二个参数&add,是将函数add的地址值填过来了,确定绑定的函数。
第三个参数:是函数的说明文档
注意:仅仅只需要少量的代码就能完成C++到Python的绑定工作,所有关于函数参数、返回值的细节,将会被模板元编程自动推导出来!这种整体的方法和语法都借鉴了Boost.Python
,但是其底层实现是完全不同的。(也就是光写好这个文件就能完成自动映射,其他的细节我们不用多管)
我们项目中的使用示例
#include <torch/serialize/tensor.h>
#include <torch/extension.h>
#include "attention/attention_cuda_kernel.h"
#include "rpe/relative_pos_encoding_cuda_kernel.h"
#include "precompute/precompute_cuda_kernel.h"
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("attention_step1_forward_cuda", &attention_step1_forward_cuda, "attention_step1_forward_cuda");
m.def("attention_step1_backward_cuda", &attention_step1_backward_cuda, "attention_step1_backward_cuda");
m.def("attention_step2_forward_cuda", &attention_step2_forward_cuda, "attention_step2_forward_cuda");
m.def("attention_step2_backward_cuda", &attention_step2_backward_cuda, "attention_step2_backward_cuda");
m.def("precompute_all_cuda", &precompute_all_cuda, "precompute_all_cuda");
m.def("dot_prod_with_idx_forward_cuda", &dot_prod_with_idx_forward_cuda, "dot_prod_with_idx_forward_cuda");
m.def("dot_prod_with_idx_backward_cuda", &dot_prod_with_idx_backward_cuda, "dot_prod_with_idx_backward_cuda");
m.def("attention_step2_with_rel_pos_value_forward_cuda", &attention_step2_with_rel_pos_value_forward_cuda, "attention_step2_with_rel_pos_value_forward_cuda");
m.def("attention_step2_with_rel_pos_value_backward_cuda", &attention_step2_with_rel_pos_value_backward_cuda, "attention_step2_with_rel_pos_value_backward_cuda");
m.def("dot_prod_with_idx_all_forward_cuda", &dot_prod_with_idx_all_forward_cuda, "dot_prod_with_idx_all_forward_cuda");
}
注意头文件 #include <torch/extension.h>很万能(下面有他的源码),他包含了all.h,python.h,可以将很多头文件给导入进来,当然也包括PYBIND11_MODULE这个函数。#include <torch/serialize/tensor.h>我觉得可以不写。
CMakeLists.txt(这个不用看,只使用setup.py编译,用这个不知道如何导出python模块)
find_package(PythonInterp REQUIRED)
cmake_minimum_required(VERSION 3.10)
project(untitled LANGUAGES CUDA CXX)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
add_executable(untitled main.cu test1.cu pointops_api.cpp)
set(CMAKE_CUDA_STANDARD 17)
set_target_properties(untitled PROPERTIES
CUDA_SEPARABLE_COMPILATION ON)
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
target_link_libraries(untitled CUDA "${TORCH_LIBRARIES}")
find_package(PythonInterp REQUIRED)
添加python编译器,否则cmake配置libtorch会报错Failed to compute shorthash for libnvrtc.so
cmake_minimum_required(VERSION 3.10)
要求使用的最低cmake版本,cmake低于这个版本不能编译该项目,可以自己设定
project(untitled LANGUAGES CUDA CXX)
untitled 为项目名,LANGUAGES CUDA CXX这个非常重要,就是我们的代码里的cu代码,和cpp,cc等代码能被正常编译,也就是同时启用CUDA代码和cxx代码的编译,
如果不加上CXX,就会报错
cmake-build-debug Unknown extension ".cc" for file
因为比如cpp文件cuda的编译器nvcc是可以编译的,但.cc文件也就是c++的源代码文件他无法编译,此时需要启用c++编译,也就是加上CXX。
find_package(Torch REQUIRED)
找到pytorch的c++文件,将pytorch导入进来。
这里是会先找pytorch的config文件,叫做 TorchConfig.cmake或torch-config.cmake,这个是pytorch关于cmake的配置文件,比如包含了去哪找pytorch的头文件,以及动态库等等,否则编译时是找不到对应的头文件的。
为了找到TorchConfig.cmake,我们需要设置一个缓存变量CMAKE_PREFIX_PATH,让他能够找到pytorch的TorchConfig.cmake的位置。
CMAKE_PREFIX_PATH=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake
该路径可以由torch.utils.cmake_prefix_path查询到
CMAKE_PREFIX_PATH作用:
用于FIND_XXX()搜索的路径,并添加适当的后缀。
指定一个将被FIND_XXX()命令使用的路径。它包含了 "基础 "目录,FIND_XXX()命令将适当的子目录附加到基础目录中。因此,FIND_PROGRAM()在路径中的每个目录中添加/bin,FIND_LIBRARY()在每个目录中添加/lib,FIND_PATH()和FIND_FILE()添加/include。默认情况下,它是空的,它的目的是由项目来设置。参见CMAKE_SYSTEM_PREFIX_PATH, CMAKE_INCLUDE_PATH, CMAKE_LIBRARY_PATH, CMAKE_PROGRAM_PATH。
FIND_PROGRAM中变为torch.utils.cmake_prefix_path/bin
FIND_PATH中变为torch.utils.cmake_prefix_path/include
找torch包变为torch.utils.cmake_prefix_path/torch 这个正是我们需要的,此时就能正确找到torch了
cmake最终写为如下,用于添加缓存变量
cmake -D CMAKE_PREFIX_PATH=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake
注意:如果下载了libpytorch(也就是单独的c++ pytorch库,不要将他的cmake文件夹导入进来,否则会报
Libtorch C++ build ‘Could NOT find Torch (missing: TORCH_LIBRARY)’
)
add_executable(untitled main.cu test1.cu pointops_api.cpp)
所有要进行编译的代码都在这声明。
include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
target_link_libraries(untitled CUDA "${TORCH_LIBRARIES}")
将pytorch头文件加入到头文件查找路径,将pytorch库文件添加到链接查找路径
cmake缓存变量
cmake缓存变量(Cache Variabl),相当于一个全局变量。在同一个CMake工程中任何地方都可以使用。
如何指定缓存变量?
- 法1 在调用cmake的时候加-D,后面的就是缓存变量
cmake -DCMAKE_PREFIX_PATH=/your/path
cmake -D CMAKE_PREFIX_PATH=/your/path
这两种都可以
- 法2 在clion中修改,这两个位置是同步的,修改其中的一个框就行,其实和法1是同一种方式
- 法3 使用set 命令
set(<variable> <value>... CACHE <type> <docstring> [FORCE]) |
- variable:变量名称
- value:变量值列表
- CACHE:cache变量的标志
- type:变量类型,取决于变量的值。类型分为:BOOL、FILEPATH、PATH、STRING、INTERNAL
- docstring:必须是字符串,作为变量概要说明
- FORCE:强制选项,强制修改变量值
-
代码结构
- learn_cmake:为根目录
- build:为CMake配置输出目录(在此例中即生成sln解决方案的地方)
- cmake_config.bat:执行CMake配置过程的脚本(双击直接运行)
- CMakeLists.txt:CMake脚本
-
示例代码(CMakeLists.txt文件内容)
cmake_minimum_required(VERSION 3.18)
# 设置工程名称
set(PROJECT_NAME KAIZEN)
# 设置工程版本号
set(PROJECT_VERSION "1.0.0.10" CACHE STRING "默认版本号")
# 工程定义
project(${PROJECT_NAME}
LANGUAGES CXX C
VERSION ${PROJECT_VERSION}
)
# 打印开始日志
message(STATUS "\n########## BEGIN_TEST_CACHE_VARIABLE")
### 定义缓存变量
# 定义一个STRIING类型缓存变量(不加FORCE选项)
set(MY_GLOBAL_VAR_STRING_NOFORCE "abcdef" CACHE STRING "定义一个STRING缓存变量")
message("MY_GLOBAL_VAR_STRING_NOFORCE: ${MY_GLOBAL_VAR_STRING_NOFORCE}")
# 定义一个STRIING类型缓存变量(加FORCE选项)
set(MY_GLOBAL_VAR_STRING "abc" CACHE STRING "定义一个STRING缓存变量" FORCE)
message("MY_GLOBAL_VAR_STRING: ${MY_GLOBAL_VAR_STRING}")
- 法4 在CMakeCache.txt中进行修改,注意这种的优先级比较低,就是使用命令行定义的变量会覆盖CMakeCache.txt的同名变量,可以说是命令行定义会覆盖CMakeCache.txt的值,每次运行cmake,比如命令行传入了CMAKE_PREFIX_PATH为aaa,那么会先修改CMakeCache.txt的CMAKE_PREFIX_PATH为aaa,再读入CMakeCache.txt的总体缓存数据。覆盖说明我在CMakeCache.txt定义的值,如果在命令行定义过了比如aaa,无论再怎么在CMakeCache.txt里修改都没有用,修改成bbb,ccc,运行一次cmake直接被改写成aaa。
当 CMake 首次在一个空的构建树中运行时,它会创建一个 CMakeCache.txt
文件并使用项目的可自定义设置填充它。此选项可用于指定优先于项目默认值的设置。可以根据需要为尽可能多的CACHE
条目重复该选项。
CMakeCache.txt文件示例
//Path to a program.
CMAKE_OBJCOPY:FILEPATH=/usr/bin/objcopy
//Path to a program.
CMAKE_OBJDUMP:FILEPATH=/usr/bin/objdump
//No help, variable specified on the command line.
CMAKE_PREFIX_PATH:UNINITIALIZED=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake
//Value Computed by CMake
CMAKE_PROJECT_DESCRIPTION:STATIC=
//Value Computed by CMake
CMAKE_PROJECT_HOMEPAGE_URL:STATIC=
可以通过一下的形式查看变量,就写在txt文件中
message("================${CMAKE_CXX_FLAGS}===============")
附录:头文件
all.h
#pragma once
#if !defined(_MSC_VER) && __cplusplus < 201402L
#error C++14 or later compatible compiler is required to use PyTorch.
#endif
#include <torch/cuda.h>
#include <torch/data.h>
#include <torch/enum.h>
#include <torch/fft.h>
#include <torch/jit.h>
#include <torch/linalg.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <torch/autograd.h>
#include <torch/version.h>
extension.h,万能头文件,一个文件包含了所有要用的东西。
#pragma once
// All pure C++ headers for the C++ frontend.
#include <torch/all.h>
// Python bindings for the C++ frontend (includes Python.h).
#include <torch/python.h>
python.h
#pragma once
#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <torch/types.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>
#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
namespace torch {
namespace python {
namespace detail {
inline Device py_object_to_device(py::object object) {
PyObject* obj = object.ptr();
if (THPDevice_Check(obj)) {
return reinterpret_cast<THPDevice*>(obj)->device;
}
throw TypeError("Expected device");
}
inline Dtype py_object_to_dtype(py::object object) {
PyObject* obj = object.ptr();
if (THPDtype_Check(obj)) {
return reinterpret_cast<THPDtype*>(obj)->scalar_type;
}
throw TypeError("Expected dtype");
}
template <typename ModuleType>
using PyModuleClass =
py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;
/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
/// to which it delegates all calls.
template <typename ModuleType>
void bind_cpp_module_wrapper(
py::module module,
PyModuleClass<ModuleType> cpp_class,
const char* name) {
// Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
// with a dynamically created class below.
py::object cpp_module =
py::module::import("torch.nn.cpp").attr("ModuleWrapper");
// Grab the `type` class which we'll use as a metaclass to create a new class
// dynamically.
py::object type_metaclass =
py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);
// The `ModuleWrapper` constructor copies all functions to its own `__dict__`
// in its constructor, but we do need to give our dynamic class a constructor.
// Inside, we construct an instance of the original C++ module we're binding
// (the `torch::nn::Module` subclass), and then forward it to the
// `ModuleWrapper` constructor.
py::dict attributes;
// `type()` always needs a `str`, but pybind11's `str()` method always creates
// a `unicode` object.
py::object name_str = py::str(name);
// Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
// `torch.nn.Module`, and will delegate all calls to the C++ module we're
// binding.
py::object wrapper_class =
type_metaclass(name_str, py::make_tuple(cpp_module), attributes);
// The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
// which replaces its methods with those of the C++ module.
wrapper_class.attr("__init__") = py::cpp_function(
[cpp_module, cpp_class](
py::object self, py::args args, py::kwargs kwargs) {
cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
},
py::is_method(wrapper_class));
// Calling `my_module.my_class` now means that `my_class` is a subclass of
// `ModuleWrapper`, and whose methods call into the C++ module we're binding.
module.attr(name) = wrapper_class;
}
} // namespace detail
/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
/// subclass.
///
/// Say you have a pybind11 class object created with `py::class_<Net>(m,
/// "Net")`. This function will add all the necessary `.def()` calls to bind the
/// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into
/// Python.
///
/// Users should prefer to use `bind_module` if possible.
template <typename ModuleType, typename... Extra>
py::class_<ModuleType, Extra...> add_module_bindings(
py::class_<ModuleType, Extra...> module) {
// clang-format off
return module
.def("train",
[](ModuleType& module, bool mode) { module.train(mode); },
py::arg("mode") = true)
.def("eval", [](ModuleType& module) { module.eval(); })
.def("clone", [](ModuleType& module) { return module.clone(); })
.def_property_readonly(
"training", [](ModuleType& module) { return module.is_training(); })
.def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
.def_property_readonly( "_parameters", [](ModuleType& module) {
return module.named_parameters(/*recurse=*/false);
})
.def("parameters", [](ModuleType& module, bool recurse) {
return module.parameters(recurse);
},
py::arg("recurse") = true)
.def("named_parameters", [](ModuleType& module, bool recurse) {
return module.named_parameters(recurse);
},
py::arg("recurse") = true)
.def_property_readonly("_buffers", [](ModuleType& module) {
return module.named_buffers(/*recurse=*/false);
})
.def("buffers", [](ModuleType& module, bool recurse) {
return module.buffers(recurse); },
py::arg("recurse") = true)
.def("named_buffers", [](ModuleType& module, bool recurse) {
return module.named_buffers(recurse);
},
py::arg("recurse") = true)
.def_property_readonly(
"_modules", [](ModuleType& module) { return module.named_children(); })
.def("modules", [](ModuleType& module) { return module.modules(); })
.def("named_modules",
[](ModuleType& module, py::object /* unused */, std::string prefix) {
return module.named_modules(std::move(prefix));
},
py::arg("memo") = py::none(),
py::arg("prefix") = std::string())
.def("children", [](ModuleType& module) { return module.children(); })
.def("named_children",
[](ModuleType& module) { return module.named_children(); })
.def("to", [](ModuleType& module, py::object object, bool non_blocking) {
if (THPDevice_Check(object.ptr())) {
module.to(
reinterpret_cast<THPDevice*>(object.ptr())->device,
non_blocking);
} else {
module.to(detail::py_object_to_dtype(object), non_blocking);
}
},
py::arg("dtype_or_device"),
py::arg("non_blocking") = false)
.def("to",
[](ModuleType& module,
py::object device,
py::object dtype,
bool non_blocking) {
if (device.is_none()) {
module.to(detail::py_object_to_dtype(dtype), non_blocking);
} else if (dtype.is_none()) {
module.to(detail::py_object_to_device(device), non_blocking);
} else {
module.to(
detail::py_object_to_device(device),
detail::py_object_to_dtype(dtype),
non_blocking);
}
},
py::arg("device"),
py::arg("dtype"),
py::arg("non_blocking") = false)
.def("cuda", [](ModuleType& module) { module.to(kCUDA); })
.def("cpu", [](ModuleType& module) { module.to(kCPU); })
.def("float", [](ModuleType& module) { module.to(kFloat32); })
.def("double", [](ModuleType& module) { module.to(kFloat64); })
.def("half", [](ModuleType& module) { module.to(kFloat16); })
.def("__str__", [](ModuleType& module) { return module.name(); })
.def("__repr__", [](ModuleType& module) { return module.name(); });
// clang-format on
}
/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
/// struct Net : torch::nn::Module {
/// Net(int in, int out) { }
/// torch::Tensor forward(torch::Tensor x) { return x; }
/// };
///
/// PYBIND11_MODULE(my_module, m) {
/// torch::python::bind_module<Net>(m, "Net")
/// .def(py::init<int, int>())
/// .def("forward", &Net::forward);
/// }
/// \endrst
template <typename ModuleType, bool force_enable = false>
torch::disable_if_t<
torch::detail::has_forward<ModuleType>::value && !force_enable,
detail::PyModuleClass<ModuleType>>
bind_module(py::module module, const char* name) {
py::module cpp = module.def_submodule("cpp");
auto cpp_class =
add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
detail::bind_cpp_module_wrapper(module, cpp_class, name);
return cpp_class;
}
/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// If the class has a `forward()` method, it is automatically exposed as
/// `forward()` and `__call__` in Python.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
/// struct Net : torch::nn::Module {
/// Net(int in, int out) { }
/// torch::Tensor forward(torch::Tensor x) { return x; }
/// };
///
/// PYBIND11_MODULE(my_module, m) {
/// torch::python::bind_module<Net>(m, "Net")
/// .def(py::init<int, int>())
/// .def("forward", &Net::forward);
/// }
/// \endrst
template <
typename ModuleType,
typename =
torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
detail::PyModuleClass<ModuleType> bind_module(
py::module module,
const char* name) {
return bind_module<ModuleType, /*force_enable=*/true>(module, name)
.def("forward", &ModuleType::forward)
.def("__call__", &ModuleType::forward);
}
} // namespace python
} // namespace torch