NVIDIA CUDA Python编程框架--Warp开发文档第六章: 互操作

NVIDIA CUDA Python编程框架–Warp开发文档第六章: 互操作

Warp 可以通过标准接口协议与其他基于 Python 的框架(例如 NumPy)进行互操作。
在这里插入图片描述

NumPy

Warp 数组可以通过 warp.array.numpy() 方法转换为 NumPy 数组。 当 Warp 数组位于 cpu 设备上时,这将返回底层 Warp 分配的零拷贝视图。 如果数组位于 cuda 设备上,那么它将首先被复制回临时缓冲区并复制到 NumPy。

Warp CPU 数组还实现了 __array_interface__ 协议,因此可以直接用于构造 NumPy 数组:

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")
a = np.array(w)
print(a)
> [1. 2. 3.]

为了方便起见,还可以使用数据类型转换实用程序:

warp_type = wp.float32
...
numpy_type = wp.dtype_to_numpy(warp_type)
...
a = wp.zeros(n, dtype=warp_type)
b = np.zeros(n, dtype=numpy_type)

  • warp.dtype_from_numpy(numpy_dtype)
    返回与 NumPy dtype 对应的 Warp dtype。

  • warp.dtype_to_numpy(warp_dtype)
    返回与 Warp dtype 对应的 NumPy dtype。

PyTorch

Warp 提供了辅助函数来将数组与 PyTorch 相互转换:

w = wp.array([1.0, 2.0, 3.0], dtype=float, device="cpu")

# convert to Torch tensor
t = wp.to_torch(w)

# convert from Torch tensor
w = wp.from_torch(t)

这些辅助函数允许在 Warp 数组与 PyTorch 张量之间进行转换,而无需复制底层数据。 同时,如果可用,梯度数组和张量会与 PyTorch autograd 张量相互转换,从而允许在 PyTorch autograd 计算中使用 Warp 数组。

示例:使用 warp.from_torch() 进行优化

使用 warp.from_torch 通过 PyTorch 的 Adam 优化器最小化以 Warp 编写的 2D 点数组上的损失函数的示例用法如下:

import warp as wp
import torch

wp.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# indicate requires_grad so that Warp can accumulate gradients in the grad buffers
xs = torch.randn(100, 2, requires_grad=True)
l = torch.zeros(1, requires_grad=True)
opt = torch.optim.Adam([xs], lr=0.1)

wp_xs = wp.from_torch(xs)
wp_l = wp.from_torch(l)

tape = wp.Tape()
with tape:
    # record the loss function kernel launch on the tape
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=wp_l)  # compute gradients
    # now xs.grad will be populated with the gradients computed by Warp
    opt.step()  # update xs (and thereby wp_xs)

    # these lines are only needed for evaluating the loss
    # (the optimization just needs the gradient, not the loss value)
    wp_l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[wp_xs], outputs=[wp_l], device=wp_xs.device)
    print(f"{i}\tloss: {l.item()}")

示例:使用 warp.to_torch 进行优化

当我们直接在 Warp 中声明优化变量并使用 warp.to_torch 将它们转换为 PyTorch 张量时,需要更少的代码。 在这里,我们重新审视上面的相同示例,现在只需要一次到PyTorch张量的转换即可为 Adam 提供优化变量:

import warp as wp
import numpy as np
import torch

wp.init()

@wp.kernel()
def loss(xs: wp.array(dtype=float, ndim=2), l: wp.array(dtype=float)):
    tid = wp.tid()
    wp.atomic_add(l, 0, xs[tid, 0] ** 2.0 + xs[tid, 1] ** 2.0)

# initialize the optimization variables in Warp
xs = wp.array(np.random.randn(100, 2), dtype=wp.float32, requires_grad=True)
l = wp.zeros(1, dtype=wp.float32, requires_grad=True)
# just a single wp.to_torch call is needed, Adam optimizes using the Warp array gradients
opt = torch.optim.Adam([wp.to_torch(xs)], lr=0.1)

tape = wp.Tape()
with tape:
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)

for i in range(500):
    tape.zero()
    tape.backward(loss=l)
    opt.step()

    l.zero_()
    wp.launch(loss, dim=len(xs), inputs=[xs], outputs=[l], device=xs.device)
    print(f"{i}\tloss: {l.numpy()[0]}")

示例:使用 torch.autograd.function 进行优化

人们可以通过定义 torch.autograd.function 类在 PyTorch 图中插入 Warp 内核启动,这需要定义前向和后向函数。 将传入的 torch 数组映射到 Warp 数组后,可以以通常的方式启动 Warp 内核。 在向后传递中,可以通过在 wp.launch() 中设置 adjoint = True 来启动同一内核的伴随程序。 或者,用户可以选择依赖 Warp 的胶带。 在下面的示例中,我们演示了如何使用 Warp 在优化上下文中评估 Rosenbrock 函数:

import warp as wp
import numpy as np
import torch

wp.init()

pvec2 = wp.types.vector(length=2, dtype=wp.float32)

# Define the Rosenbrock function
@wp.func
def rosenbrock(x: float, y: float):
    return (1.0 - x) ** 2.0 + 100.0 * (y - x**2.0) ** 2.0

@wp.kernel
def eval_rosenbrock(
    xs: wp.array(dtype=pvec2),
    # outputs
    z: wp.array(dtype=float),
):
    i = wp.tid()
    x = xs[i]
    z[i] = rosenbrock(x[0], x[1])


class Rosenbrock(torch.autograd.Function):
    @staticmethod
    def forward(ctx, xy, num_points):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()

        ctx.xy = wp.from_torch(xy, dtype=pvec2, requires_grad=True)
        ctx.num_points = num_points

        # allocate output
        ctx.z = wp.zeros(num_points, requires_grad=True)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z]
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()

        return wp.to_torch(ctx.z)

    @staticmethod
    def backward(ctx, adj_z):
        # ensure Torch operations complete before running Warp
        wp.synchronize_device()

        # map incoming Torch grads to our output variables
        ctx.z.grad = wp.from_torch(adj_z)

        wp.launch(
            kernel=eval_rosenbrock,
            dim=ctx.num_points,
            inputs=[ctx.xy],
            outputs=[ctx.z],
            adj_inputs=[ctx.xy.grad],
            adj_outputs=[ctx.z.grad],
            adjoint=True
        )

        # ensure Warp operations complete before returning data to Torch
        wp.synchronize_device()

        # return adjoint w.r.t. inputs
        return (wp.to_torch(ctx.xy.grad), None)


num_points = 1500
learning_rate = 5e-2

torch_device = wp.device_to_torch(wp.get_device())

rng = np.random.default_rng(42)
xy = torch.tensor(rng.normal(size=(num_points, 2)), dtype=torch.float32, requires_grad=True, device=torch_device)
opt = torch.optim.Adam([xy], lr=learning_rate)

for _ in range(10000):
    # step
    opt.zero_grad()
    z = Rosenbrock.apply(xy, num_points)
    z.backward(torch.ones_like(z))

    opt.step()

# minimum at (1, 1)
xy_np = xy.numpy(force=True)
print(np.mean(xy_np, axis=0))

  • warp.from_torch(t, dtype=None, requires_grad=None, grad=None)
    将 Torch 张量转换为 Warp 数组,而不复制数据。

    • 参数:
      t (torch.Tensor) – 要包装的torch张量。

    • dtype (warp.dtype, 可选) – 生成的 Warp 数组的目标数据类型。 默认为映射到 Warp 数组值类型的张量值类型。

    • require_grad (bool, 可选) – 结果数组是否应该包含张量的梯度(如果存在)(否则将分配梯度张量)。 默认为张量的requires_grad值。

    • 返回:
      warp数组。

    • 返回类型:
      warp数组

  • warp.to_torch(a, require_grad=None)
    将 Warp 数组转换为 Torch 张量,而不复制数据。

    • 参数:
      a (warp.array) – 要转换的 Warp 数组。

    • require_grad (bool, 可选) – 生成的张量是否应将数组的梯度(如果存在)转换为梯度张量。 默认为数组的requires_grad值。

    • 返回:
      转换后的张量。

    • 返回类型:
      torch张量

  • warp.device_from_torch(torch_device)
    返回与 Torch 设备对应的 Warp 设备。

  • warp.device_to_torch(warp_device)
    返回与 Warp 设备对应的 Torch 设备。

  • warp.dtype_from_torch(torch_dtype)
    返回与 Torch dtype 对应的 Warp dtype。

  • warp.dtype_to_torch(warp_dtype)
    返回对应于 Warp dtype 的 Torch dtype。

CuPy/Numba

Warp GPU 阵列支持 __cuda_array_interface__ 协议,用于与其他 Python GPU 框架共享数据。 目前这是单向的,因此 Warp 数组可以用作任何也支持 __cuda_array_interface__ 协议的框架的输入,但反之则不然。

JAX

通过以下方法支持与 JAX 数组的互操作性。 在内部,它们使用 DLPack 协议以零拷贝方式与 JAX 交换数据:


warp_array = wp.from_jax(jax_array)
jax_array = wp.to_jax(warp_array)

为了获得更好的性能和对流同步行为的控制,最好直接使用 DLPack 协议。

  • warp.from_jax(jax_array, dtype=None)
    将 Jax 数组转换为 Warp 数组而不复制数据。

    • 参数:
      jax_array – 要转换的 Jax 数组。

    • dtype (可选) – 生成的 Warp 数组的目标数据类型。 默认为映射到 Warp 数据类型的 Jax 数组的数据类型。

    • 返回:
      转换后的 Warp 数组。

    • 返回类型:
      warp数组

  • warp.to_jax(warp_array)
    将 Warp 数组转换为 Jax 数组而不复制数据。

    • 参数:
      warp_array (warp.array) – 要转换的 Warp 数组。

    • 返回:
      转换后的 Jax 数组。

*warp.device_from_jax(jax_device)
返回与 Jax 设备对应的 Warp 设备。

  • warp.device_to_jax(warp_device)
    返回与 Warp 设备对应的 Jax 设备。

  • warp.dtype_from_jax(jax_dtype)
    返回与 Jax dtype 对应的 Warp dtype。

  • warp.dtype_to_jax(warp_dtype)
    返回对应于 Warp 数据类型的 Jax 数据类型。

使用 Warp 内核作为 JAX 原语

注意

这是正在开发的实验性功能。

Warp 内核可以用作 JAX 原语,它可用于在 jitted JAX 函数内部调用 Warp 内核:

import warp as wp
import jax
import jax.numpy as jp

# import experimental feature
from warp.jax_experimental import jax_kernel

@wp.kernel
def triple_kernel(input: wp.array(dtype=float), output: wp.array(dtype=float)):
    tid = wp.tid()
    output[tid] = 3.0 * input[tid]

wp.init()

# create a Jax primitive from a Warp kernel
jax_triple = jax_kernel(triple_kernel)

# use the Warp kernel in a Jax jitted function
@jax.jit
def f():
    x = jp.arange(0, 64, dtype=jp.float32)
    return jax_triple(x)

print(f())

由于这是一个实验性功能,因此存在一些限制:

  • 所有内核参数都必须是数组。

  • 内核启动维度是根据第一个参数的形状推断出来的。

  • 在 Warp 内核定义中,输入参数后面跟着输出参数。

  • 必须至少有一个输入参数和至少一个输出参数。

  • 输出形状必须与启动尺寸匹配(即输出形状必须与第一个参数的形状匹配)。

  • 所有数组必须是连续的。

  • 仅支持 CUDA 后端。

以下是具有三个输入和两个输出的操作的示例:

import warp as wp
import jax
import jax.numpy as jp

# import experimental feature
from warp.jax_experimental import jax_kernel

# kernel with multiple inputs and outputs
@wp.kernel
def multiarg_kernel(
    # inputs
    a: wp.array(dtype=float),
    b: wp.array(dtype=float),
    c: wp.array(dtype=float),
    # outputs
    ab: wp.array(dtype=float),
    bc: wp.array(dtype=float),
):
    tid = wp.tid()
    ab[tid] = a[tid] + b[tid]
    bc[tid] = b[tid] + c[tid]

wp.init()

# create a Jax primitive from a Warp kernel
jax_multiarg = jax_kernel(multiarg_kernel)

# use the Warp kernel in a Jax jitted function with three inputs and two outputs
@jax.jit
def f():
    a = jp.full(64, 1, dtype=jp.float32)
    b = jp.full(64, 2, dtype=jp.float32)
    c = jp.full(64, 3, dtype=jp.float32)
    return jax_multiarg(a, b, c)

x, y = f()

print(x)
print(y)

DLPack

Warp 支持 Python Array API 标准 v2022.12 中包含的 DLPack 协议。 请参阅 DLPack 的 Python 规范以供参考。

将外部数组导入 Warp 的规范方法是使用 warp.from_dlpack() 函数:

warp_array = wp.from_dlpack(external_array)

外部数组可以是 PyTorch 张量、Jax 数组或与此版本的 DLPack 协议兼容的任何其他数组类型。 对于 CUDA 数组,此方法要求生产者执行流同步,以确保数组上的操作正确排序。 warp.from_dlpack() 函数要求生产者同步数组所在设备上的当前 Warp 流。 因此,在该设备上的 Warp 内核中使用数组应该是安全的,无需任何额外的同步。

将 Warp 数组导出到外部框架的规范方法是使用该框架中的 from_dlpack() 函数:

jax_array = jax.dlpack.from_dlpack(warp_array)
torch_tensor = torch.utils.dlpack.from_dlpack(warp_array)

对于 CUDA 数组,这会将消费者框架的当前流与数组设备上的当前 Warp 流同步。 因此,在消费者框架中使用包装数组应该是安全的,即使该数组之前已在设备上的 Warp 内核中使用过。

或者,可以通过使用生产者框架提供的 to_dlpack() 函数显式创建 PyCapsule 来共享数组。 此方法可用于不支持 v2022.12 标准的旧版本框架:

warp_array1 = wp.from_dlpack(jax.dlpack.to_dlpack(jax_array))
warp_array2 = wp.from_dlpack(torch.utils.dlpack.to_dlpack(torch_tensor))

jax_array = jax.dlpack.from_dlpack(wp.to_dlpack(warp_array))
torch_tensor = torch.utils.dlpack.from_dlpack(wp.to_dlpack(warp_array))

这种方法通常更快,因为它跳过任何流同步,但必须使用另一种解决方案来确保操作的正确顺序。 在不需要同步的情况下,使用此方法可以产生更好的性能。 在以下情况下这可能是一个不错的选择:

  • 外部框架使用同步 CUDA 默认流。

  • Warp 和外部框架使用相同的 CUDA 流。

  • 另一种同步机制已经到位。

  • warp.from_dlpack(source, dtype=None)
    将源数组或 DLPack 胶囊转换为 Warp 数组,无需复制。

    参数:

    • source – DLPack 兼容的数组或 PyCapsule

    • dtype – 用于解释源数据的可选 Warp 数据类型。

    • 返回:
      一个新的 Warp 数组,使用与输入 pycapsule 相同的底层内存。

    • 返回类型:
      array

  • warp.to_dlpack(wp_array)
    将 Warp 数组转换为另一种类型的 dlpack 兼容数组。

    参数:

    • wp_array (array) – 将被转换的源 Warp 数组。

    • 返回:
      包含 DLManagedTensor 的胶囊,可以将其转换为另一种数组类型,而无需复制底层内存。

  • 28
    点赞
  • 10
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

扫地的小何尚

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值