pytorch torch.vmap函数介绍

torch.vmap 是 PyTorch 提供的一个高效矢量化映射函数,用于对批量数据上的操作进行自动矢量化。它可以显著提高代码的性能和可读性,避免显式使用循环来操作批量数据。


torch.vmap 的核心功能

  • 对函数进行批量化操作。
  • 自动扩展函数,使其可以作用于批量输入(即 N 个样本)。
  • 提供对批量维度的灵活控制,包括指定输入输出的批量维度。

函数签名

torch.vmap(func, in_dims=0, out_dims=0)
参数
  1. func:

    • 要矢量化的函数(可以是用户定义函数,也可以是 PyTorch 函数)。
    • 必须接收张量作为输入,并返回张量或元组。
  2. in_dims:

    • 指定输入张量的批量维度,默认为 0
    • 如果输入是多个张量,可以传递一个元组,表示每个输入的批量维度。
    • 若 in_dims=None,表示输入不需要矢量化。
  3. out_dims:

    • 指定函数输出的批量维度,默认为 0
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值