matlab、numpy的indexing差异

用 index 数组从数据集里抽取一批数据,matlab 的写法与 python numpy 的写法不同:matlab 除了在 sample 那一维放索引数组,还要在其余维度加冒号 :(表示该维全选?)。

numpy

import numpy as np


images = np.zeros([5, 32, 32, 3])  # [n, H, W, C]
labels = np.ones([5, 10])  # [n, c]
indices = np.array([0, 2, 1])  # 0-base
print("indices:", indices)  # [0 2 1]

image_batch = images[indices]
label_batch = labels[indices]
print("image batch:", image_batch.shape)  # (3, 32, 32, 3)
print("label batch:", label_batch.shape)  # (3, 10)

matlab

images = zeros(5, 32, 32, 3);  % [n, H, W, C]
labels = ones(5, 10);  % [n, c]
indices = int8([0, 2, 1])' + 1;  % 1-base
fprintf("indices:"), disp(indices);  % [1, 3, 2]


image_batch = images(indices);
label_batch = labels(indices);
fprintf("image batch 1:"), disp(size(image_batch));  % (3, 1)
fprintf("label batch 1:"), disp(size(label_batch));  % (3, 1)

image_batch = images(indices, :);
label_batch = labels(indices, :);
fprintf("image batch 2:"), disp(size(image_batch));  % (3, 3072), 其中 3072 = 32 * 32 * 3
fprintf("label batch 2:"), disp(size(label_batch));  % (3, 10)

image_batch = images(indices, :, :);
fprintf("image batch 3:"), disp(size(image_batch));  % (3, 32, 96), 其中 96 = 32 * 3

image_batch = images(indices, :, :, :);
fprintf("image batch 4:"), disp(size(image_batch));  % (3, 32, 32, 3)
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值