数据集取法导致速度不一样的问题

今天写作业的时候,取mnist数据集
写了两种取数据的方法,一个要跑一小时,一个只用两分钟
记录一下

先上慢的

	(input_a,input_b,label) = dataset
    input_a = np.array(input_a)
    input_b = np.array(input_b)
    label = np.array(label)
    index = np.arange(len(dataset[0]))
    #print(index)
    np.random.shuffle(index)
    input_a_n = input_a[index]
    input_b_n = input_b[index]
    input_a_n = input_a_n[:batch_size]
    input_b_n = input_b_n[:batch_size]
    labels_n = label[index]
    labels_n = labels_n[:batch_size]
    #print('shape a,shape b,shape label',input_a_n.shape,input_b_n.shape,labels_n.shape)

因为这里先将下标打乱,然后将数组整个根据下标也打乱,主要时间消耗在将数组打乱这一步上

看一下快的

input_a_n, input_b_n, labels_n = [], [], []
    index = [i for i in range(0, len(dataset[0]))]
    #print(index)
    np.random.shuffle(index)
    for i in range(batch_size):
        j = index[i]
        input_a_n.append(dataset[0][j])
        input_b_n.append(dataset[1][j])
        labels_n.append(int(dataset[2][j]))

这里先是将下标打乱,然后根据打乱的下标取数,这样就不用将数组打乱,速度提升很多

水平有限,O(∩_∩)O哈哈~

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值