for batch, data in enumerate(dataloader[phase], 1)为什么从索引1开始(enumerate函数的说明)

1. 动机

在学习《深度学习之PyTorch实战计算机视觉》这本书的“模型融合”部分的时候,我遇到了一些问题。具体来说,对于 for batch, data in enumerate(dataloader[phase], 1): 中的 enumerate() 函数该如何理解?

2. enumerate 函数定义

通过查看 enumerate() 的文档:

enumerate(iterable, start=0)
"""Return an enumerate object. 

iterable must be a sequence, an iterator, or some other object which supports iteration. 

The __next__() method of the iterator returned by enumerate() returns a tuple containing a count (from start which defaults to 0) and the values obtained from iterating over iterable."""
  • numerate() 会返回两个值, 一个是索引, 一个是数据
  • numerate() 需要两个参数:
    1. 第一个参数是可迭代的对象
    2. 第二个参数是起始位置, 数据类型为 int

3. enumerate 函数代码示例

seasons = ['Spring', 'Summer', 'Fall', 'Winter']

for idx, data in list(enumerate(seasons)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 0, data: Spring    idx: 1, data: Summer    idx: 2, data: Fall    idx: 3, data: Winter
print("")

for idx, data in list(enumerate(seasons, 0)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 0, data: Spring    idx: 1, data: Summer    idx: 2, data: Fall    idx: 3, data: Winter
print("")

for idx, data in list(enumerate(seasons, start=1)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 1, data: Spring    idx: 2, data: Summer    idx: 3, data: Fall    idx: 4, data: Winter
print("")

for idx, data in list(enumerate(seasons, start=5)):
    print("idx: {}, data: {}    ".format(idx, data), end="")
    # Res: idx: 5, data: Spring    idx: 6, data: Summer    idx: 7, data: Fall    idx: 8, data: Winter

结合以上的代码结果, 我们观察到一个规律:默认情况下,numerate()start 参数默认为 0,即返回的第一个值默认是从0 开始的, 而且这个值只会影响返回索引的值, 对数据没有任何影响。所以这也就解释了在 PyTorch 训练部分, 一般 start=1 而不是 0

4. 解释问题

我们看一下训练的部分代码:

for batch, data in enumerate(dataloader[phase], 1):
    x, y = data
    if torch.cuda.is_available():
        x, y = Variable(x.cuda()), Variable(y.cuda())
    else:
        x, y = Variable(x), Variable(y)

    # 前向传播
    y_pred_1 = model_1(x)
    y_pred_2 = model_2(x)
    
    # 根据权重融合两个输出
    blending_y_pred = y_pred_1 * weight_1 + y_pred_2 * weight_2

    pred_1 = torch.max(y_pred_1.data, 1)[1]
    pred_2 = torch.max(y_pred_2.data, 1)[1]
    blending_pred = torch.max(blending_y_pred.data, 1)[1]

    # 清空梯度
    optimizer_1.zero_grad()
    optimizer_2.zero_grad()

    # 计算损失
    loss_1 = loss_Fn_1(y_pred_1, y)
    loss_2 = loss_Fn_2(y_pred_2, y)

    """
    先判断是在训练还是在验证: 
        如果在训练则开始进行计算反向传播, 并更新梯度
        如果在验证则开始不进行计算反向传播, 不更新梯度
    """
    if phase == "train":
        # 反向传播
        loss_1.backward()
        loss_2.backward()

        # 梯度更新
        optimizer_1.step()
        optimizer_2.step()

    running_loss_1 += loss_1.item()
    running_corrects_1 += torch.sum(pred_1 == y.data)
    running_loss_2 += loss_2.item()
    running_corrects_2 += torch.sum(pred_2 == y.data)
    blending_running_corrects += torch.sum(blending_pred == y.data)

    if batch % 500 == 0 and phase == "train":
        print(
            f"Batch {batch}:\n "
            "--------------------------------------------------------------------\n"
            f"Model_1 Train Loss:{running_loss_1 / batch:.4f}, "
            f"Model_1 Train Acc:{100 * running_corrects_1 / (16 * batch):.4f}\n"
            f"Model_2 Train Loss:{running_loss_2 / batch:.4f}, "
            f"Model_2 Train Acc:{100 * running_corrects_2 / (16 * batch):.4f}\n "
            "--------------------------------------------------------------------\n"
            f"Blending_Model Acc:{100 * blending_running_corrects / (16 * batch):.4f}%"
        )

epoch_loss_1 = running_loss_1 * 16 / len(image_datasets[phase])
epoch_acc_1 = 100 * running_corrects_1 / len(image_datasets[phase])
epoch_loss_2 = running_loss_2 * 16 / len(image_datasets[phase])
epoch_acc_2 = 100 * running_corrects_2 / len(image_datasets[phase])
epoch_blending_acc = 100 * blending_running_corrects / len(image_datasets[phase])

print(
    f"Model_1 Loss:{epoch_loss_1:.4f}, Model_1 Acc:{epoch_acc_1:.4f}%\n "
    f"Model_2 Loss:{epoch_loss_2:.4f}, Model_2 Acc:{epoch_acc_2:.4f}%\n "
    f"Blending_Model Acc:{epoch_blending_acc:.4f}%"
)

time_end = time.time()
print(f"Total Time is:{time_end - time_start:.2f}")

通过阅读代码我们发现, for batch, data in enumerate(dataloader[phase], 1): 将返回的索引给了batch, 而batch 在之后打印输出训练结果的时候有用到(即 if batch % 500 == 0 and phase == "train":)。让 enumerate(data, start=1) 是为了让 batch1 开始。

5. 总结

这篇文章主要探讨的是 enumerate 这个 Python 内置的函数,它会对一个可迭代对象进行包装,返回两个值:

  1. 索引
  2. 可迭代对象本身

其中 enumerate 函数中的 start 参数决定了索引的开始值(默认为 0)。我们可以根据我们的需要设置不同的 start,从而灵活地使用 enumerate 函数。

  • 6
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值