自己的完整c++ cuda包

pytorch关于c++的所有文档集合

Welcome to PyTorch Tutorials — PyTorch Tutorials 2.0.0+cu117 documentation

1.前置条件

使用编辑器clion,安装好cudatoolkit,cudnn,pytorch环境,编译工具gcc等等。

记得要设置好cudatoolkit的环境变量和动态链接库,这样到时候才能找到cudatoolkit和cudnn

安装教程可看

https://mp.csdn.net/mp_blog/creation/editor/new/129111146

注意我们如果要使用pytorch 的c语言版,是不需要安装额外的libpytorch的,因为pytorch下载的时候就自动整合了这些。

官方教程

CUDA projects | CLion Documentation

Installing C++ Distributions of PyTorch — PyTorch master documentation

2.通过clion创建cuda可执行项目

参照Installing C++ Distributions of PyTorch — PyTorch master documentation

 这两个文件

以及 CMakeLists.txt我们是不需要的,我们使用setup.py代替 CMakeLists.txt

setup.py

参考官方文档

2. Writing the Setup Script — Python 3.6.15 documentation

以及pytorch的 setup.py教程,写的很详细

Custom C++ and CUDA Extensions — PyTorch Tutorials 2.0.0+cu117 documentation

文件项目结构

setup.py的安装代码模板

#python3 setup.py install
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import os
from distutils.sysconfig import get_config_vars

(opt,) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(
    flag for flag in opt.split() if flag != '-Wstrict-prototypes'
)

setup(
    name='sptr',
    ext_modules=[
        CUDAExtension('sptr_cuda', [
            'src/sptr/pointops_api.cpp',
            'src/sptr/attention/attention_cuda.cpp',
            'src/sptr/attention/attention_cuda_kernel.cu',
            'src/sptr/precompute/precompute.cpp',
            'src/sptr/precompute/precompute_cuda_kernel.cu',
            'src/sptr/rpe/relative_pos_encoding_cuda.cpp',
            'src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu',
            ],
        extra_compile_args={'cxx': ['-g'], 'nvcc': ['-O2', '-g', '-G']}
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)
  • setup的是一个包,要将什么包给安装上来,是我们要生成的动态链接库的名字
  • name='sptr' 是包名,执行python3 setup.py install会安装一个叫sptr的包
  • ext_modules表明我要输出的模块,模块才是真正能被python代码调用的,而不是包!,比如我写
  • import sptr是找不到模块的,因为他根本就不是模块!,调用import  sptr_cuda才会有效。
  • CUDAExtension就是拓展模块,比如我有模块sptr_cuda,与他绑定的有哪些cpp文件我写过来,配合pointops_api.cpp(也就是第一行),可以将指定的cpp接口暴露给sptr_cuda模块,使得python代码可以调用。
  • extra_compile_args 就是传给 gcc 的额外的编译参数,比方说你可以传一个 -std=c++11

这里c语言的编译器用的是cxx应该也就是gcc不知道为啥要叫做cxx,nvcc就是cu代码的编译器,它也可以编译c++语言。

       

         'nvcc': ['-O2', '-g', '-G'] -O2参数含义O2该优化选项会牺牲部分编译速度,除了执行-O1所执行的所有优化之外,还会采用几乎所有的目标配置支持的优化算法,用以提高目标代码的运行速度。

        -g,-G

NVCC, the NVIDIA CUDA compiler driver, provides a mechanism for generating the debugging information necessary for CUDA-GDB to work properly. The -g -G option pair must be passed to NVCC when an application is compiled for ease of debugging with CUDA-GDB; for example,

        也就是生成调试信息,只有nvcc 添加上这两个选项,后面才能链接生成可以被cuda-gdb调试的可执行文件

        gcc -g只是编译器,在编译的时候,产生调试信息,通俗来讲是后面生成的可执行文件能够被gdb调试,如果不加-g的话 gdb是无法调试的。

        

GCC中-O1 -O2 -O3 优化的原理是什么? - 知乎

  • cmdclass将BuildExtension类给传入了,
  • torch.utils.cpp_extension.BuildExtension(dist,** kw )

简单来说就是提供参数的,我们直接写就好了

自定义setuptools构建扩展。

setuptools.build_ext子类负责传递所需的最小编译器参数(例如-std=c++11)以及混合的C ++/CUDA编译(以及一般对CUDA文件的支持)。

当使用BuildExtension时,它将提供一个用于extra_compile_args(不是普通列表)的词典,通过语言(cxxcuda)映射到参数列表提供给编译器。这样可以在混合编译期间为C ++CUDA编译器提供不同的参数。

(opt,) = get_config_vars('OPT')
os.environ['OPT'] = " ".join(
    flag for flag in opt.split() if flag != '-Wstrict-prototypes'
) #设置环境变量opt

目的:创建环境变量opt,里面是执行setup.py传入的默认参数

Wstrict-prototypes:确定是否为未指定参数类型声明或定义的函数发出警告

原先的opt为字符串'-DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes'

结果os.environ['OPT'] 为   '-DNDEBUG -g -fwrapv -O3 -Wall',将-Wstrict-prototypes去除了,其他和get_config_vars('OPT')一样,就是不发出这一种警告了。

执行安装

python3 setup.py install

可以看到我们安装好的sptr在和其他包相同的位置

 包是叫sptr-0.0.0-py3.7-linux-x86_64.egg的文件夹

 打开后就可以看到我们导出的模块了,我们import导入的就是sptr_cuda.py,然后他又指向动态链接库sptr_cuda.cpython-37m-x86_64-linux-gnu.so,他是我们编译好的动态链接库(就是在运行时去动态的找头文件对应的实现的编译内容),pycache文件就是sptr_cuda.py的对应pyc文件。

 EGG-INFO文件夹下存储了一些包的相关信息,其中比如source文件夹就记录了源代码的名称

README.md
setup.py
sptr.egg-info/PKG-INFO
sptr.egg-info/SOURCES.txt
sptr.egg-info/dependency_links.txt
sptr.egg-info/top_level.txt
src/sptr/pointops_api.cpp
src/sptr/attention/attention_cuda.cpp
src/sptr/attention/attention_cuda_kernel.cu
src/sptr/precompute/precompute.cpp
src/sptr/precompute/precompute_cuda_kernel.cu
src/sptr/rpe/relative_pos_encoding_cuda.cpp
src/sptr/rpe/relative_pos_encoding_cuda_kernel.cu
test/test_attention_op_step1.py
test/test_attention_op_step2.py
test/test_precompute_all.py
test/test_relative_pos_encoding_op_step1.py
test/test_relative_pos_encoding_op_step1_all.py
test/test_relative_pos_encoding_op_step2.py

可以据此定位到项目的源代码的位置(可能之后的调试代码的定位也是基于这个原理)

头文件在include文件夹下,so文件在ld_library_path下,然后暴露接口(使用PYBIND11_MODULE),最终导出模块,此时python就可以调用模块的接口了,所以so文件也就是封装好的c语言函数或者类。python调用c++接口的步骤如下:python导入模块,这个模块在site-packages里被找到,比如叫sptr_cuda.py,sptr_cuda.py里代理了很多c++的函数,这些实现都在sptr_cuda.cpython-37m-x86_64-linux-gnu.so中,当python调用函数,就在这里进行寻找实现,so文件完成计算后返回给接口,python程序就得到返回值了。

pointops_api.cpp

可以将指定的cpp接口暴露给sptr_cuda模块,使得python代码可以调用。

pybind11 具体用法

参考

跟我一起学习pybind11 之一 - 腾讯云开发者社区-腾讯云

绑定简单函数

让我们以一个极度简单的函数来开始创建python绑定,函数完成两数相加并返回结果

int add(int i, int j)
{
    return i + j;
}

为简单起见,我们将函数和绑定代码都放在example.cpp这个文件中

#include <pybind11/pybind11.h>
namespace py = pybind11;

int add(int i, int j)
{
    return i + j;
}

PYBIND11_MODULE(example, m)
{
    m.doc() = "pybind11 example plugin"; // 可选的模块说明

    m.def("add", &add, "A function which adds two numbers");
}

PYBIND11_MODULE()宏函数将会创建一个函数,在由Python发起import语句时该函数将会被调用(也就是生成模块)。模块名字“example”由宏的第一个参数指定(千万不能出现引号),比如下面代码就传入sptr第二个参数"m",定义了一个py::module的变量,实际也就是我们调用的moudle,传入python的模块。

m.doc:定义该模块的模块文档

m.def:定义该模块的映射参数,函数py::module::def()生成绑定代码,将add()函数暴露给Python。

第一个参数"add",表示我以后要在python中通过 模块名.add来调用函数

第二个参数&add,是将函数add的地址值填过来了,确定绑定的函数。

第三个参数:是函数的说明文档

注意:仅仅只需要少量的代码就能完成C++到Python的绑定工作,所有关于函数参数、返回值的细节,将会被模板元编程自动推导出来!这种整体的方法和语法都借鉴了Boost.Python,但是其底层实现是完全不同的。(也就是光写好这个文件就能完成自动映射,其他的细节我们不用多管

我们项目中的使用示例

#include <torch/serialize/tensor.h>
#include <torch/extension.h>

#include "attention/attention_cuda_kernel.h"
#include "rpe/relative_pos_encoding_cuda_kernel.h"
#include "precompute/precompute_cuda_kernel.h"


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("attention_step1_forward_cuda", &attention_step1_forward_cuda, "attention_step1_forward_cuda");
    m.def("attention_step1_backward_cuda", &attention_step1_backward_cuda, "attention_step1_backward_cuda");
    m.def("attention_step2_forward_cuda", &attention_step2_forward_cuda, "attention_step2_forward_cuda");
    m.def("attention_step2_backward_cuda", &attention_step2_backward_cuda, "attention_step2_backward_cuda");
    m.def("precompute_all_cuda", &precompute_all_cuda, "precompute_all_cuda");
    m.def("dot_prod_with_idx_forward_cuda", &dot_prod_with_idx_forward_cuda, "dot_prod_with_idx_forward_cuda");
    m.def("dot_prod_with_idx_backward_cuda", &dot_prod_with_idx_backward_cuda, "dot_prod_with_idx_backward_cuda");
    m.def("attention_step2_with_rel_pos_value_forward_cuda", &attention_step2_with_rel_pos_value_forward_cuda, "attention_step2_with_rel_pos_value_forward_cuda");
    m.def("attention_step2_with_rel_pos_value_backward_cuda", &attention_step2_with_rel_pos_value_backward_cuda, "attention_step2_with_rel_pos_value_backward_cuda");
    m.def("dot_prod_with_idx_all_forward_cuda", &dot_prod_with_idx_all_forward_cuda, "dot_prod_with_idx_all_forward_cuda");
}

注意头文件 #include <torch/extension.h>很万能(下面有他的源码),他包含了all.h,python.h,可以将很多头文件给导入进来,当然也包括PYBIND11_MODULE这个函数。#include <torch/serialize/tensor.h>我觉得可以不写。

 CMakeLists.txt(这个不用看,只使用setup.py编译,用这个不知道如何导出python模块)

find_package(PythonInterp REQUIRED)
cmake_minimum_required(VERSION 3.10)
project(untitled LANGUAGES CUDA CXX)
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")

add_executable(untitled main.cu test1.cu pointops_api.cpp)
set(CMAKE_CUDA_STANDARD 17)



set_target_properties(untitled PROPERTIES
        CUDA_SEPARABLE_COMPILATION ON)


include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})
target_link_libraries(untitled CUDA "${TORCH_LIBRARIES}")

find_package(PythonInterp REQUIRED)

添加python编译器,否则cmake配置libtorch会报错Failed to compute shorthash for libnvrtc.so

cmake_minimum_required(VERSION 3.10)

要求使用的最低cmake版本,cmake低于这个版本不能编译该项目,可以自己设定

project(untitled LANGUAGES CUDA CXX)

untitled 为项目名,LANGUAGES CUDA CXX这个非常重要,就是我们的代码里的cu代码,和cpp,cc等代码能被正常编译,也就是同时启用CUDA代码和cxx代码的编译,

如果不加上CXX,就会报错

cmake-build-debug Unknown extension ".cc" for file

因为比如cpp文件cuda的编译器nvcc是可以编译的,但.cc文件也就是c++的源代码文件他无法编译,此时需要启用c++编译,也就是加上CXX。

find_package(Torch REQUIRED)

找到pytorch的c++文件,将pytorch导入进来。

这里是会先找pytorch的config文件,叫做    TorchConfig.cmake或torch-config.cmake,这个是pytorch关于cmake的配置文件,比如包含了去哪找pytorch的头文件,以及动态库等等,否则编译时是找不到对应的头文件的。

为了找到TorchConfig.cmake,我们需要设置一个缓存变量CMAKE_PREFIX_PATH,让他能够找到pytorch的TorchConfig.cmake的位置。

CMAKE_PREFIX_PATH=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake

该路径可以由torch.utils.cmake_prefix_path查询到

 CMAKE_PREFIX_PATH作用:

用于FIND_XXX()搜索的路径,并添加适当的后缀。

指定一个将被FIND_XXX()命令使用的路径。它包含了 "基础 "目录,FIND_XXX()命令将适当的子目录附加到基础目录中。因此,FIND_PROGRAM()在路径中的每个目录中添加/bin,FIND_LIBRARY()在每个目录中添加/lib,FIND_PATH()和FIND_FILE()添加/include。默认情况下,它是空的,它的目的是由项目来设置。参见CMAKE_SYSTEM_PREFIX_PATH, CMAKE_INCLUDE_PATH, CMAKE_LIBRARY_PATH, CMAKE_PROGRAM_PATH。
FIND_PROGRAM中变为torch.utils.cmake_prefix_path/bin

FIND_PATH中变为torch.utils.cmake_prefix_path/include

找torch包变为torch.utils.cmake_prefix_path/torch 这个正是我们需要的,此时就能正确找到torch了

cmake最终写为如下,用于添加缓存变量

cmake -D CMAKE_PREFIX_PATH=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake

注意:如果下载了libpytorch(也就是单独的c++ pytorch库,不要将他的cmake文件夹导入进来,否则会报

Libtorch C++ build ‘Could NOT find Torch (missing: TORCH_LIBRARY)’

add_executable(untitled main.cu test1.cu pointops_api.cpp)

所有要进行编译的代码都在这声明。

include_directories(SYSTEM ${TORCH_INCLUDE_DIRS})

target_link_libraries(untitled CUDA "${TORCH_LIBRARIES}")

将pytorch头文件加入到头文件查找路径,将pytorch库文件添加到链接查找路径

cmake缓存变量

cmake缓存变量(Cache Variabl),相当于一个全局变量。在同一个CMake工程中任何地方都可以使用。

如何指定缓存变量?

  • 法1 在调用cmake的时候加-D,后面的就是缓存变量
cmake -DCMAKE_PREFIX_PATH=/your/path
cmake -D CMAKE_PREFIX_PATH=/your/path

这两种都可以

  • 法2 在clion中修改,这两个位置是同步的,修改其中的一个框就行,其实和法1是同一种方式

  • 法3 使用set 命令
set(<variable> <value>... CACHE <type> <docstring> [FORCE])
  • variable:变量名称
  • value:变量值列表
  • CACHE:cache变量的标志
  • type:变量类型,取决于变量的值。类型分为:BOOL、FILEPATH、PATH、STRING、INTERNAL
  • docstring:必须是字符串,作为变量概要说明
  • FORCE:强制选项,强制修改变量值
  • 代码结构

    • learn_cmake:为根目录
    • build:为CMake配置输出目录(在此例中即生成sln解决方案的地方)
    • cmake_config.bat:执行CMake配置过程的脚本(双击直接运行)
    • CMakeLists.txt:CMake脚本
  • 示例代码(CMakeLists.txt文件内容)


	cmake_minimum_required(VERSION 3.18)

	


	# 设置工程名称

	set(PROJECT_NAME KAIZEN)

	


	# 设置工程版本号

	set(PROJECT_VERSION "1.0.0.10" CACHE STRING "默认版本号")

	


	# 工程定义

	project(${PROJECT_NAME}

	LANGUAGES CXX C

	VERSION ${PROJECT_VERSION}

	)

	


	# 打印开始日志

	message(STATUS "\n########## BEGIN_TEST_CACHE_VARIABLE")


	### 定义缓存变量



	# 定义一个STRIING类型缓存变量(不加FORCE选项)

	set(MY_GLOBAL_VAR_STRING_NOFORCE "abcdef" CACHE STRING "定义一个STRING缓存变量")

	message("MY_GLOBAL_VAR_STRING_NOFORCE: ${MY_GLOBAL_VAR_STRING_NOFORCE}")



	# 定义一个STRIING类型缓存变量(加FORCE选项)

	set(MY_GLOBAL_VAR_STRING "abc" CACHE STRING "定义一个STRING缓存变量" FORCE)

	message("MY_GLOBAL_VAR_STRING: ${MY_GLOBAL_VAR_STRING}")

  •  法4 在CMakeCache.txt中进行修改,注意这种的优先级比较低,就是使用命令行定义的变量会覆盖CMakeCache.txt的同名变量,可以说是命令行定义会覆盖CMakeCache.txt的值,每次运行cmake,比如命令行传入了CMAKE_PREFIX_PATH为aaa,那么会先修改CMakeCache.txt的CMAKE_PREFIX_PATH为aaa,再读入CMakeCache.txt的总体缓存数据。覆盖说明我在CMakeCache.txt定义的值,如果在命令行定义过了比如aaa,无论再怎么在CMakeCache.txt里修改都没有用,修改成bbb,ccc,运行一次cmake直接被改写成aaa。

当 CMake 首次在一个空的构建树中运行时,它会创建一个 CMakeCache.txt文件并使用项目的可自定义设置填充它。此选项可用于指定优先于项目默认值的设置。可以根据需要为尽可能多的CACHE条目重复该选项。

CMakeCache.txt文件示例

//Path to a program.
CMAKE_OBJCOPY:FILEPATH=/usr/bin/objcopy

//Path to a program.
CMAKE_OBJDUMP:FILEPATH=/usr/bin/objdump

//No help, variable specified on the command line.
CMAKE_PREFIX_PATH:UNINITIALIZED=/home/zxy/mambaforge/envs/sphere/lib/python3.7/site-packages/torch/share/cmake

//Value Computed by CMake
CMAKE_PROJECT_DESCRIPTION:STATIC=

//Value Computed by CMake
CMAKE_PROJECT_HOMEPAGE_URL:STATIC=

可以通过一下的形式查看变量,就写在txt文件中

message("================${CMAKE_CXX_FLAGS}===============")

附录:头文件

all.h

#pragma once

#if !defined(_MSC_VER) && __cplusplus < 201402L
#error C++14 or later compatible compiler is required to use PyTorch.
#endif

#include <torch/cuda.h>
#include <torch/data.h>
#include <torch/enum.h>
#include <torch/fft.h>
#include <torch/jit.h>
#include <torch/linalg.h>
#include <torch/nn.h>
#include <torch/optim.h>
#include <torch/serialize.h>
#include <torch/types.h>
#include <torch/utils.h>
#include <torch/autograd.h>
#include <torch/version.h>

extension.h,万能头文件,一个文件包含了所有要用的东西。

#pragma once

// All pure C++ headers for the C++ frontend.
#include <torch/all.h>
// Python bindings for the C++ frontend (includes Python.h).
#include <torch/python.h>

python.h

#pragma once

#include <torch/detail/static.h>
#include <torch/nn/module.h>
#include <torch/ordered_dict.h>
#include <torch/types.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/Dtype.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/pybind.h>

#include <iterator>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

namespace torch {
namespace python {
namespace detail {
inline Device py_object_to_device(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDevice_Check(obj)) {
    return reinterpret_cast<THPDevice*>(obj)->device;
  }
  throw TypeError("Expected device");
}

inline Dtype py_object_to_dtype(py::object object) {
  PyObject* obj = object.ptr();
  if (THPDtype_Check(obj)) {
    return reinterpret_cast<THPDtype*>(obj)->scalar_type;
  }
  throw TypeError("Expected dtype");
}

template <typename ModuleType>
using PyModuleClass =
    py::class_<ModuleType, torch::nn::Module, std::shared_ptr<ModuleType>>;

/// Dynamically creates a subclass of `torch.nn.cpp.ModuleWrapper` that is also
/// a subclass of `torch.nn.Module`, and passes it the user-provided C++ module
/// to which it delegates all calls.
template <typename ModuleType>
void bind_cpp_module_wrapper(
    py::module module,
    PyModuleClass<ModuleType> cpp_class,
    const char* name) {
  // Grab the `torch.nn.cpp.ModuleWrapper` class, which we'll subclass
  // with a dynamically created class below.
  py::object cpp_module =
      py::module::import("torch.nn.cpp").attr("ModuleWrapper");

  // Grab the `type` class which we'll use as a metaclass to create a new class
  // dynamically.
  py::object type_metaclass =
      py::reinterpret_borrow<py::object>((PyObject*)&PyType_Type);

  // The `ModuleWrapper` constructor copies all functions to its own `__dict__`
  // in its constructor, but we do need to give our dynamic class a constructor.
  // Inside, we construct an instance of the original C++ module we're binding
  // (the `torch::nn::Module` subclass), and then forward it to the
  // `ModuleWrapper` constructor.
  py::dict attributes;

  // `type()` always needs a `str`, but pybind11's `str()` method always creates
  // a `unicode` object.
  py::object name_str = py::str(name);

  // Dynamically create the subclass of `ModuleWrapper`, which is a subclass of
  // `torch.nn.Module`, and will delegate all calls to the C++ module we're
  // binding.
  py::object wrapper_class =
      type_metaclass(name_str, py::make_tuple(cpp_module), attributes);

  // The constructor of the dynamic class calls `ModuleWrapper.__init__()`,
  // which replaces its methods with those of the C++ module.
  wrapper_class.attr("__init__") = py::cpp_function(
      [cpp_module, cpp_class](
          py::object self, py::args args, py::kwargs kwargs) {
        cpp_module.attr("__init__")(self, cpp_class(*args, **kwargs));
      },
      py::is_method(wrapper_class));

  // Calling `my_module.my_class` now means that `my_class` is a subclass of
  // `ModuleWrapper`, and whose methods call into the C++ module we're binding.
  module.attr(name) = wrapper_class;
}
} // namespace detail

/// Adds method bindings for a pybind11 `class_` that binds an `nn::Module`
/// subclass.
///
/// Say you have a pybind11 class object created with `py::class_<Net>(m,
/// "Net")`. This function will add all the necessary `.def()` calls to bind the
/// `nn::Module` base class' methods, such as `train()`, `eval()` etc. into
/// Python.
///
/// Users should prefer to use `bind_module` if possible.
template <typename ModuleType, typename... Extra>
py::class_<ModuleType, Extra...> add_module_bindings(
    py::class_<ModuleType, Extra...> module) {
  // clang-format off
  return module
      .def("train",
          [](ModuleType& module, bool mode) { module.train(mode); },
          py::arg("mode") = true)
      .def("eval", [](ModuleType& module) { module.eval(); })
      .def("clone", [](ModuleType& module) { return module.clone(); })
      .def_property_readonly(
          "training", [](ModuleType& module) { return module.is_training(); })
      .def("zero_grad", [](ModuleType& module) { module.zero_grad(); })
      .def_property_readonly( "_parameters", [](ModuleType& module) {
            return module.named_parameters(/*recurse=*/false);
          })
      .def("parameters", [](ModuleType& module, bool recurse) {
            return module.parameters(recurse);
          },
          py::arg("recurse") = true)
      .def("named_parameters", [](ModuleType& module, bool recurse) {
            return module.named_parameters(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly("_buffers", [](ModuleType& module) {
            return module.named_buffers(/*recurse=*/false);
          })
      .def("buffers", [](ModuleType& module, bool recurse) {
            return module.buffers(recurse); },
          py::arg("recurse") = true)
      .def("named_buffers", [](ModuleType& module, bool recurse) {
            return module.named_buffers(recurse);
          },
          py::arg("recurse") = true)
      .def_property_readonly(
        "_modules", [](ModuleType& module) { return module.named_children(); })
      .def("modules", [](ModuleType& module) { return module.modules(); })
      .def("named_modules",
          [](ModuleType& module, py::object /* unused */, std::string prefix) {
            return module.named_modules(std::move(prefix));
          },
          py::arg("memo") = py::none(),
          py::arg("prefix") = std::string())
      .def("children", [](ModuleType& module) { return module.children(); })
      .def("named_children",
          [](ModuleType& module) { return module.named_children(); })
      .def("to", [](ModuleType& module, py::object object, bool non_blocking) {
            if (THPDevice_Check(object.ptr())) {
              module.to(
                  reinterpret_cast<THPDevice*>(object.ptr())->device,
                  non_blocking);
            } else {
              module.to(detail::py_object_to_dtype(object), non_blocking);
            }
          },
          py::arg("dtype_or_device"),
          py::arg("non_blocking") = false)
      .def("to",
          [](ModuleType& module,
             py::object device,
             py::object dtype,
             bool non_blocking) {
              if (device.is_none()) {
                module.to(detail::py_object_to_dtype(dtype), non_blocking);
              } else if (dtype.is_none()) {
                module.to(detail::py_object_to_device(device), non_blocking);
              } else {
                module.to(
                    detail::py_object_to_device(device),
                    detail::py_object_to_dtype(dtype),
                    non_blocking);
              }
          },
          py::arg("device"),
          py::arg("dtype"),
          py::arg("non_blocking") = false)
      .def("cuda", [](ModuleType& module) { module.to(kCUDA); })
      .def("cpu", [](ModuleType& module) { module.to(kCPU); })
      .def("float", [](ModuleType& module) { module.to(kFloat32); })
      .def("double", [](ModuleType& module) { module.to(kFloat64); })
      .def("half", [](ModuleType& module) { module.to(kFloat16); })
      .def("__str__", [](ModuleType& module) { return module.name(); })
      .def("__repr__", [](ModuleType& module) { return module.name(); });
  // clang-format on
}

/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///  }
/// \endrst
template <typename ModuleType, bool force_enable = false>
torch::disable_if_t<
    torch::detail::has_forward<ModuleType>::value && !force_enable,
    detail::PyModuleClass<ModuleType>>
bind_module(py::module module, const char* name) {
  py::module cpp = module.def_submodule("cpp");
  auto cpp_class =
      add_module_bindings(detail::PyModuleClass<ModuleType>(cpp, name));
  detail::bind_cpp_module_wrapper(module, cpp_class, name);
  return cpp_class;
}

/// Creates a pybind11 class object for an `nn::Module` subclass type and adds
/// default bindings.
///
/// After adding the default bindings, the class object is returned, such that
/// you can add more bindings.
///
/// If the class has a `forward()` method, it is automatically exposed as
/// `forward()` and `__call__` in Python.
///
/// Example usage:
/// \rst
/// .. code-block:: cpp
///
///   struct Net : torch::nn::Module {
///     Net(int in, int out) { }
///     torch::Tensor forward(torch::Tensor x) { return x; }
///   };
///
///   PYBIND11_MODULE(my_module, m) {
///     torch::python::bind_module<Net>(m, "Net")
///       .def(py::init<int, int>())
///       .def("forward", &Net::forward);
///  }
/// \endrst
template <
    typename ModuleType,
    typename =
        torch::enable_if_t<torch::detail::has_forward<ModuleType>::value>>
detail::PyModuleClass<ModuleType> bind_module(
    py::module module,
    const char* name) {
  return bind_module<ModuleType, /*force_enable=*/true>(module, name)
      .def("forward", &ModuleType::forward)
      .def("__call__", &ModuleType::forward);
}
} // namespace python
} // namespace torch

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值