pytorch根据labels对应位置取值 | 换一句话省两小时

问题

给出一个列表i_labels = [1,0,1,0],当取label ==1时,取images_true中的对应位置元素,当label == 0时,取images_false中对应位置的元素。如:

解决方法1:torch.stack

images = torch.stack([images_true[idx] if label == 1 else images_false[idx] for idx, label in enumerate(i_labels)])

解决方法2:对位相乘后相加

s = i_labels[:,None,None,None]
images = images * s + images_false * (1 - s)
images

比较

解决方法1解决方法2运行时间,均运行10W次,统计运行时间

解决方法1 耗时7.06秒,解决方法2 耗时4.14秒。

总结

不同语句的运行时间是存在很大差距的,特别是用for语句。在我的30W数据集上跑,在一个epoch中使用解决方法2解决方法1 ,可以节省2个小时。一条语句,省两个小时。

如您有更快的方式,欢迎讨论,让速度更快吧。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值