功能
torch.bucketize
是 PyTorch 中的一个函数,用于将张量中的值按照指定的边界(bins)进行离散化。它返回每个值所属分箱(bin)的索引。
基本语法
torch.bucketize(input, boundaries, *, out_int32=False, right=False)
参数
-
input
:- 输入张量,包含需要离散化的值。
-
boundaries
:- 分箱边界,必须是单调递增的 1D 张量(或 Python 数组)。
-
out_int32
(可选,默认False
):- 如果为
True
,返回的分箱索引是int32
类型。 - 如果为
False
,返回int64
类型。
- 如果为
-
right
(可选,默认False
):- 如果为
False
,分箱边界为左闭右开[boundaries[i-1], boundaries[i])
。 - 如果为
True
,分箱边界为左开右闭(boundaries[i-1], boundaries[i]]
。
- 如果为
返回值
- 一个与输入形状相同的张量,元素为每个输入值所属的分箱索引。
使用场景
- 数据离散化:将连续数据分成固定范围的区间。
- 编码分箱:为数据分箱后生成索引,用于特征工程或模型输入。
示例代码
1. 基本用法
import torch
# 输入数据
input = torch.tensor([1.5, 2.5, 3.5, 4.5, 5.5])
# 分箱边界
boundaries = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
# 计算分箱索引
output = torch.bucketize(input, boundaries)
print("Input:", input)
print("Boundaries:", boundaries)
print("Output (Indices):", output)
2. 批量分箱
# 批量输入
inputs = torch.tensor([[1.5, 4.0, 6.0], [2.1, 3.3, 5.5]])
boundaries = torch.tensor([1.0, 3.0, 5.0])
# 分箱索引
batch_output = torch.bucketize(inputs, boundaries)
print("Inputs:", inputs)
print("Boundaries:", boundaries)
print("Batch Output:", batch_output)
输出:
Inputs: tensor([[1.5000, 4.0000, 6.0000],
[2.1000, 3.3000, 5.5000]])
Boundaries: tensor([1., 3., 5.])
Batch Output: tensor([[1, 2, 3],
[1, 2, 3]])
3. 结合 One-Hot 编码
# 分箱边界
boundaries = torch.tensor([0.0, 2.0, 4.0, 6.0])
# 输入数据
input_data = torch.tensor([1.5, 3.0, 5.5])
# 分箱索引
indices = torch.bucketize(input_data, boundaries)
# 转换为 One-Hot 编码
one_hot = torch.nn.functional.one_hot(indices, num_classes=len(boundaries) + 1)
print("Input Data:", input_data)
print("Indices:", indices)
print("One-Hot Encoding:\n", one_hot)
输出:
Input Data: tensor([1.5000, 3.0000, 5.5000])
Indices: tensor([1, 2, 3])
One-Hot Encoding:
tensor([[0, 1, 0, 0, 0],
[0, 0, 1, 0, 0],
[0, 0, 0, 1, 0]])
总结
torch.bucketize
是一种高效的分箱工具,适用于连续数据离散化、特征工程、分箱索引生成等场景。- 配合 One-Hot 编码,可以进一步将离散索引用于神经网络输入。