【琐碎】如何理解zip(*batch)

读代码的时候看到data, label = zip(*batch)这样一句,很好奇它究竟实现了什么操作
利用zip(*)命令将batch解压开,当有多个迭代器,希望它们能以相同索引一起输出时,会使用zip(*)命令

class Student(object):
    def __init__(self, score):
        self.score = score
    def __iter__(self):
        return self  # 对于迭代器来说,__iter__ 返回的是它自身self,也就是返回迭代器。

    def __next__(self):
        if self.score < 120:
            self.score += 1
            return [self.score]*5,2
        else:
            raise StopIteration()
test = Student(95)
print(isinstance(test,  Iterable))
print(isinstance(test,  Iterator))
ans=zip(*test)
data,label=list(ans)
print(data,label)
print(len(data),len(label))

输出

True
True
([96, 96, 96, 96, 96], [97, 97, 97, 97, 97], [98, 98, 98, 98, 98], [99, 99, 99, 99, 99], [100, 100, 100, 100, 100], [101, 101, 101, 101, 101], [102, 102, 102, 102, 102], [103, 103, 103, 103, 103], [104, 104, 104, 104, 104], [105, 105, 105, 105, 105], [106, 106, 106, 106, 106], [107, 107, 107, 107, 107], [108, 108, 108, 108, 108], [109, 109, 109, 109, 109], [110, 110, 110, 110, 110], [111, 111, 111, 111, 111], [112, 112, 112, 112, 112], [113, 113, 113, 113, 113], [114, 114, 114, 114, 114], [115, 115, 115, 115, 115], [116, 116, 116, 116, 116], [117, 117, 117, 117, 117], [118, 118, 118, 118, 118], [119, 119, 119, 119, 119], [120, 120, 120, 120, 120]) (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
25 25

参考
https://blog.csdn.net/u010848594/article/details/106026597

  • 4
    点赞
  • 8
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在Python中,`zip()`函数可以将多个可迭代对象打包成一个元组列表。而在深度学习中,经常需要将数据分成小批次进行训练,使用`zip()`函数可以将多个数据组成一个小批次。但是,通常我们还需要把一个小批次的数据拆分成多个数组,这时就可以使用`list(zip(*batch))`的方式。其中,`batch`是一个小批次的数据,它通常是由多个数组组成的元组列表。 具体来说,`list(zip(*batch))`的作用是将一个小批次的数据进行转置。例如,如果一个小批次的数据包含3个数组,每个数组的形状分别为`(batch_size, input_size)`,那么使用`zip()`函数将它们打包后得到的元组列表的长度为`batch_size`,每个元组包含3个元素,分别是3个数组中的元素。而使用`list(zip(*batch))`则可以将这个元组列表转置为3个数组,每个数组的形状为`(batch_size, input_size)`。 下面是一个使用`list(zip(*batch))`的简单示例: ``` import torch # 模拟一批数据,包含两个数组 batch = [(torch.randn(3), torch.randn(3)), (torch.randn(3), torch.randn(3))] # 将数据转置为两个数组 inputs, targets = list(zip(*batch)) # 打印转置后的数组形状 print(inputs[0].shape) print(targets[0].shape) ``` 在上面的示例中,我们首先模拟了一个包含两个数组的小批次数据,每个数组的形状为`(3,)`。然后,我们使用`list(zip(*batch))`将它们转置为两个数组,每个数组的形状为`(2, 3)`。最后,我们打印了转置后的数组形状,可以看到它们的形状已经发生了变化。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值