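For single-label data, each label may be stored as a single integer class ID, and we want to convert it into a one-hot label vector. Note that the original class IDs may be 0-based; because MATLAB indexing is 1-based, add 1 to shift them to 1-based IDs first.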
Code
N_CLASS = 10;
labels = [1 3 2 5 0 4]; % 0-base
% isvector(labels)
% isrow(labels)
% iscolumn(labels)
% isa(int32(labels), 'integer')
% class(labels)
L = onehot(int32(labels + 1), N_CLASS); % shift to 1-base
disp(L);
function L = onehot(vec, n_class)
% convert sparse class IDs to one-hot label vectors
% Input:
% vec: [1, n] row or [n, 1] column vector
% n_class: int, # of classes
% Output:
% L: [n, n_class] one-hot class label vectors
%----------------------------------------------
assert(isvector(vec), "`vec` must be a vector");
assert(isa(vec, 'integer'), "`vec` must be integers");
assert(all(vec >= 1 & vec <= n_class), "`vec` must be 1-based IDs in [1, n_class]");
if isrow(vec)
vec = vec'; % -> column vector
end
I = eye(n_class);
L = I(vec, :);
end
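Indexing rows of `eye(n_class)` allocates a full `n_class`-by-`n_class` identity matrix, which can become large when there are many classes. A minimal alternative sketch, assuming the same 1-based integer class IDs, writes the ones directly with linear indexing via the built-in `sub2ind`. The function name `onehot_sub2ind` is hypothetical and not part of the original code.

```matlab
function L = onehot_sub2ind(vec, n_class)
% Hypothetical alternative to onehot: same [n, n_class] double output,
% but without building the n_class-by-n_class identity matrix.
% vec: row or column vector of 1-based integer class IDs
% n_class: number of classes
vec = double(vec(:));                      % force a double column vector
n = numel(vec);
L = zeros(n, n_class);                     % start from an all-zero matrix
idx = sub2ind([n, n_class], (1:n)', vec);  % linear index of (i, vec(i)); errors if out of range
L(idx) = 1;                                % exactly one 1 per row
end
```

For valid input it returns the same matrix as `I(vec, :)` with `I = eye(n_class)`; the memory cost is `n * n_class` instead of `n_class^2` plus the output.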