import numpy as np
import torch
a = torch.rand((6,20,4))
a = a.cpu().numpy()
aaa = []
for i in range(0,6):
aa = []
for j,data in enumerate(a[i,:,-1]):
if data > 0.5:
b = a.tolist()
aa.append(b[i][j])
else: continue
# 在每个子列表中随机抽取相同数量的值
aa_np = np.array(aa)
# np二维数组的随机sample
row_rand_aa = np.arange(aa_np.shape[0])
np.random.shuffle(row_rand_aa)
aa_sample_np = aa_np[row_rand_aa[0:6]]
aa = aa_sample_np.tolist()
aaa.append(aa)
aaa_np = np.array(aaa) # 子列表维度不相等无法转换成数组
print(aaa_np)
行随机抽取
row_rand_array = np.arange(array.shape[0])
np.random.shuffle(row_rand_array)
row_rand = array[row_rand_array[0:2]]
列随机抽取
col_rand_array = np.arange(array.shape[1])
np.random.shuffle(col_rand_array)
col_rand = array[col_rand_array[0:2]]