今天写作业的时候,取mnist数据集
写了两种取数据的方法,一个要跑一小时,一个只用两分钟
记录一下
先上慢的
(input_a,input_b,label) = dataset
input_a = np.array(input_a)
input_b = np.array(input_b)
label = np.array(label)
index = np.arange(len(dataset[0]))
#print(index)
np.random.shuffle(index)
input_a_n = input_a[index]
input_b_n = input_b[index]
input_a_n = input_a_n[:batch_size]
input_b_n = input_b_n[:batch_size]
labels_n = label[index]
labels_n = labels_n[:batch_size]
#print('shape a,shape b,shape label',input_a_n.shape,input_b_n.shape,labels_n.shape)
因为这里先将下标打乱,然后将数组整个根据下标也打乱,主要时间消耗在将数组打乱这一步上
看一下快的
input_a_n, input_b_n, labels_n = [], [], []
index = [i for i in range(0, len(dataset[0]))]
#print(index)
np.random.shuffle(index)
for i in range(batch_size):
j = index[i]
input_a_n.append(dataset[0][j])
input_b_n.append(dataset[1][j])
labels_n.append(int(dataset[2][j]))
这里先是将下标打乱,然后根据打乱的下标取数,这样就不用将数组打乱,速度提升很多
水平有限,O(∩_∩)O哈哈~