torch.bucketize函数介绍

功能

torch.bucketize 是 PyTorch 中的一个函数,用于将张量中的值按照指定的边界(bins)进行离散化。它返回每个值所属分箱(bin)的索引。

基本语法

torch.bucketize(input, boundaries, *, out_int32=False, right=False)

参数

  1. input:

    • 输入张量,包含需要离散化的值。
  2. boundaries:

    • 分箱边界,必须是单调递增的 1D 张量(或 Python 数组)。
  3. out_int32 (可选,默认 False):

    • 如果为 True,返回的分箱索引是 int32 类型。
    • 如果为 False,返回 int64 类型。
  4. right (可选,默认 False):

    • 如果为 False,分箱边界为左闭右开 [boundaries[i-1], boundaries[i])
    • 如果为 True,分箱边界为左开右闭 (boundaries[i-1], boundaries[i]]

返回值

  • 一个与输入形状相同的张量,元素为每个输入值所属的分箱索引。

使用场景

  • 数据离散化:将连续数据分成固定范围的区间。
  • 编码分箱:为数据分箱后生成索引,用于特征工程或模型输入。

示例代码

1. 基本用法
import torch

# 输入数据
input = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5])
# 分箱边界
boundaries = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])

# 计算分箱索引
output = torch.bucketize(input, boundaries)

print("Input:", input)
print("Boundaries:", boundaries)
print("Output (Indices):", output)
2. 批量分箱
# 批量输入
inputs = torch.tensor([[1.5, 4.0, 6.0], [2.1, 3.3, 5.5]])
boundaries = torch.tensor([1.0, 3.0, 5.0])

# 分箱索引
batch_output = torch.bucketize(inputs, boundaries)

print("Inputs:", inputs)
print("Boundaries:", boundaries)
print("Batch Output:", batch_output)

输出:

Inputs: tensor([[1.5000, 4.0000, 6.0000],
                [2.1000, 3.3000, 5.5000]])
Boundaries: tensor([1., 3., 5.])
Batch Output: tensor([[1, 2, 3],
                      [1, 2, 3]])
3. 结合 One-Hot 编码
# 分箱边界
boundaries = torch.tensor([0.0, 2.0, 4.0, 6.0])
# 输入数据
input_data = torch.tensor([1.5, 3.0, 5.5])

# 分箱索引
indices = torch.bucketize(input_data, boundaries)

# 转换为 One-Hot 编码
one_hot = torch.nn.functional.one_hot(indices, num_classes=len(boundaries) + 1)

print("Input Data:", input_data)
print("Indices:", indices)
print("One-Hot Encoding:\n", one_hot)

输出:

Input Data: tensor([1.5000, 3.0000, 5.5000])
Indices: tensor([1, 2, 3])
One-Hot Encoding:
 tensor([[0, 1, 0, 0, 0],
         [0, 0, 1, 0, 0],
         [0, 0, 0, 1, 0]])

总结

  • torch.bucketize 是一种高效的分箱工具,适用于连续数据离散化、特征工程、分箱索引生成等场景。
  • 配合 One-Hot 编码,可以进一步将离散索引用于神经网络输入。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值