1. 动机
在学习《深度学习之PyTorch实战计算机视觉》这本书的“模型融合”部分的时候,我遇到了一些问题。具体来说,对于 for batch, data in enumerate(dataloader[phase], 1):
中的 enumerate()
函数该如何理解?
2. enumerate 函数定义
通过查看 enumerate()
的文档:
enumerate(iterable, start=0)
"""Return an enumerate object.
iterable must be a sequence, an iterator, or some other object which supports iteration.
The __next__() method of the iterator returned by enumerate() returns a tuple containing a count (from start which defaults to 0) and the values obtained from iterating over iterable."""
numerate()
会返回两个值, 一个是索引, 一个是数据numerate()
需要两个参数:- 第一个参数是可迭代的对象
- 第二个参数是起始位置, 数据类型为
int
3. enumerate 函数代码示例
seasons = ['Spring', 'Summer', 'Fall', 'Winter']
for idx, data in list(enumerate(seasons)):
print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 0, data: Spring idx: 1, data: Summer idx: 2, data: Fall idx: 3, data: Winter
print("")
for idx, data in list(enumerate(seasons, 0)):
print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 0, data: Spring idx: 1, data: Summer idx: 2, data: Fall idx: 3, data: Winter
print("")
for idx, data in list(enumerate(seasons, start=1)):
print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 1, data: Spring idx: 2, data: Summer idx: 3, data: Fall idx: 4, data: Winter
print("")
for idx, data in list(enumerate(seasons, start=5)):
print("idx: {}, data: {} ".format(idx, data), end="")
# Res: idx: 5, data: Spring idx: 6, data: Summer idx: 7, data: Fall idx: 8, data: Winter
结合以上的代码结果, 我们观察到一个规律:默认情况下,numerate()
的 start
参数默认为 0
,即返回的第一个值默认是从0
开始的, 而且这个值只会影响返回索引的值, 对数据没有任何影响。所以这也就解释了在 PyTorch 训练部分, 一般 start=1
而不是 0
。
4. 解释问题
我们看一下训练的部分代码:
for batch, data in enumerate(dataloader[phase], 1):
x, y = data
if torch.cuda.is_available():
x, y = Variable(x.cuda()), Variable(y.cuda())
else:
x, y = Variable(x), Variable(y)
# 前向传播
y_pred_1 = model_1(x)
y_pred_2 = model_2(x)
# 根据权重融合两个输出
blending_y_pred = y_pred_1 * weight_1 + y_pred_2 * weight_2
pred_1 = torch.max(y_pred_1.data, 1)[1]
pred_2 = torch.max(y_pred_2.data, 1)[1]
blending_pred = torch.max(blending_y_pred.data, 1)[1]
# 清空梯度
optimizer_1.zero_grad()
optimizer_2.zero_grad()
# 计算损失
loss_1 = loss_Fn_1(y_pred_1, y)
loss_2 = loss_Fn_2(y_pred_2, y)
"""
先判断是在训练还是在验证:
如果在训练则开始进行计算反向传播, 并更新梯度
如果在验证则开始不进行计算反向传播, 不更新梯度
"""
if phase == "train":
# 反向传播
loss_1.backward()
loss_2.backward()
# 梯度更新
optimizer_1.step()
optimizer_2.step()
running_loss_1 += loss_1.item()
running_corrects_1 += torch.sum(pred_1 == y.data)
running_loss_2 += loss_2.item()
running_corrects_2 += torch.sum(pred_2 == y.data)
blending_running_corrects += torch.sum(blending_pred == y.data)
if batch % 500 == 0 and phase == "train":
print(
f"Batch {batch}:\n "
"--------------------------------------------------------------------\n"
f"Model_1 Train Loss:{running_loss_1 / batch:.4f}, "
f"Model_1 Train Acc:{100 * running_corrects_1 / (16 * batch):.4f}\n"
f"Model_2 Train Loss:{running_loss_2 / batch:.4f}, "
f"Model_2 Train Acc:{100 * running_corrects_2 / (16 * batch):.4f}\n "
"--------------------------------------------------------------------\n"
f"Blending_Model Acc:{100 * blending_running_corrects / (16 * batch):.4f}%"
)
epoch_loss_1 = running_loss_1 * 16 / len(image_datasets[phase])
epoch_acc_1 = 100 * running_corrects_1 / len(image_datasets[phase])
epoch_loss_2 = running_loss_2 * 16 / len(image_datasets[phase])
epoch_acc_2 = 100 * running_corrects_2 / len(image_datasets[phase])
epoch_blending_acc = 100 * blending_running_corrects / len(image_datasets[phase])
print(
f"Model_1 Loss:{epoch_loss_1:.4f}, Model_1 Acc:{epoch_acc_1:.4f}%\n "
f"Model_2 Loss:{epoch_loss_2:.4f}, Model_2 Acc:{epoch_acc_2:.4f}%\n "
f"Blending_Model Acc:{epoch_blending_acc:.4f}%"
)
time_end = time.time()
print(f"Total Time is:{time_end - time_start:.2f}")
通过阅读代码我们发现, for batch, data in enumerate(dataloader[phase], 1):
将返回的索引给了batch
, 而batch
在之后打印输出训练结果的时候有用到(即 if batch % 500 == 0 and phase == "train":
)。让 enumerate(data, start=1)
是为了让 batch
从 1
开始。
5. 总结
这篇文章主要探讨的是 enumerate
这个 Python 内置的函数,它会对一个可迭代对象进行包装,返回两个值:
- 索引
- 可迭代对象本身
其中 enumerate
函数中的 start
参数决定了索引的开始值(默认为 0)。我们可以根据我们的需要设置不同的 start
,从而灵活地使用 enumerate
函数。