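用索引(index)数组从数据集中抽取一批数据时,MATLAB 的写法与 Python NumPy 的写法不同:NumPy 只需在样本(sample)维放入索引数组,其余维度默认整维保留;MATLAB 除了在样本维放索引数组外,还要在其余每个维度写上冒号 `:`(表示该维全部选取)。若冒号个数少于剩余维数,MATLAB 会把末尾各维合并成一维;若只给一个下标,则变成线性索引,按元素而非按样本取值(见下方示例)。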
NumPy
import numpy as np
# Toy dataset stored sample-first: 5 images of 32x32 with 3 channels,
# and one length-10 label row per sample.
images = np.zeros((5, 32, 32, 3))  # layout [n, H, W, C]
labels = np.ones((5, 10))  # layout [n, c]

# NumPy indices are 0-based.
indices = np.array([0, 2, 1])
print("indices:", indices)  # [0 2 1]

# Integer-array indexing on the first axis alone selects whole samples:
# every trailing axis is kept intact, so no ':' per dimension is needed
# (images[indices, :, :, :] would give the same result).
image_batch, label_batch = images[indices], labels[indices]
print("image batch:", image_batch.shape)  # (3, 32, 32, 3)
print("label batch:", label_batch.shape)  # (3, 10)
MATLAB
% Toy dataset stored sample-first.
images = zeros(5, 32, 32, 3); % [n, H, W, C]
labels = ones(5, 10); % [n, c]

% MATLAB indices are 1-based. Keep them in the default double class: an
% integer class such as int8 saturates at intmax('int8') = 127, so on a
% dataset with more than 127 samples any larger index would silently be
% clamped to 127 and fetch the wrong sample.
indices = [0; 2; 1] + 1; % 1-based column vector [1; 3; 2]
fprintf("indices:"), disp(indices);

% A single subscript is linear indexing: it picks individual elements, not
% samples, and the result takes the shape of the index vector.
image_batch = images(indices);
label_batch = labels(indices);
fprintf("image batch 1:"), disp(size(image_batch)); % (3, 1)
fprintf("label batch 1:"), disp(size(label_batch)); % (3, 1)

% Two subscripts: all dimensions after the first are merged into one.
image_batch = images(indices, :);
label_batch = labels(indices, :);
fprintf("image batch 2:"), disp(size(image_batch)); % (3, 3072), where 3072 = 32 * 32 * 3
fprintf("label batch 2:"), disp(size(label_batch)); % (3, 10)

% Three subscripts on a 4-D array: the last two dimensions are merged.
image_batch = images(indices, :, :);
fprintf("image batch 3:"), disp(size(image_batch)); % (3, 32, 96), where 96 = 32 * 3

% One subscript per dimension keeps the full per-sample shape.
image_batch = images(indices, :, :, :);
fprintf("image batch 4:"), disp(size(image_batch)); % (3, 32, 32, 3)