在机器学习和深度学习模型中,独热编码(one-hot encoding)是类别标签最常见的编码形式之一。下面是 ChatGPT 针对“按指定类别数目对数值标签进行独热编码”这一问题给出的解答;类似的问题交给它,很多时候能替我省下不少试错成本。
import numpy as np
from sklearn.preprocessing import OneHotEncoder
def numerical_to_one_hot(labels, num_classes):
    """
    Convert numerical labels into one-hot encoded vectors with specified number of classes.

    Every encoded row always has exactly ``num_classes`` columns, including
    the two-class case (no column is dropped for binary problems).

    Parameters:
    - labels: List or array of integer-valued labels in ``[0, num_classes)``.
      Input of any shape is flattened before encoding.
    - num_classes: Number of classes, i.e. the width of each encoded vector.
      Must be at least 1.

    Returns:
    - one_hot_encoded: 2D float64 numpy array of shape
      ``(number_of_labels, num_classes)``. Empty input yields an array of
      shape ``(0, num_classes)``.

    Raises:
    - ValueError: If ``num_classes`` is smaller than 1, or if any label is
      not integer-valued or lies outside ``[0, num_classes)``.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be at least 1, got {num_classes}")

    flat_labels = np.asarray(labels).reshape(-1)
    if flat_labels.size == 0:
        # Nothing to encode; keep the column count so callers can still stack results.
        return np.zeros((0, num_classes), dtype=np.float64)

    indices = flat_labels.astype(np.int64)
    # The cast truncates silently, so compare back to catch values such as 1.5.
    if not np.array_equal(indices, flat_labels):
        raise ValueError("labels must be integer-valued")

    # Explicit range check: negative values would otherwise wrap around
    # through numpy's negative indexing instead of being rejected.
    out_of_range = (indices < 0) | (indices >= num_classes)
    if out_of_range.any():
        unknown = np.unique(indices[out_of_range]).tolist()
        raise ValueError(
            f"Found unknown labels {unknown}; expected values in [0, {num_classes})"
        )

    one_hot_encoded = np.zeros((indices.size, num_classes), dtype=np.float64)
    one_hot_encoded[np.arange(indices.size), indices] = 1.0
    return one_hot_encoded
# Demo: encode five labels against four possible classes. Class 3 never
# occurs in the labels, yet it still gets its own (all-zero) column.
num_classes = 4
numerical_labels = [0, 2, 1, 0, 2]
one_hot_encoded = numerical_to_one_hot(
    labels=numerical_labels,
    num_classes=num_classes,
)
print("Original labels:", numerical_labels)
print("One-hot encoded labels:\n", one_hot_encoded)