image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
true_label = os.path.normpath(image_filepath).split(os.sep)[-2]
predicted_label = predicted_labels[i] if predicted_labels else true_label
color = “green” if true_label == predicted_label else “red”
ax.ravel()[i].imshow(image)
ax.ravel()[i].set_title(predicted_label, color=color)
ax.ravel()[i].set_axis_off()
plt.tight_layout()
plt.show()
display_image_grid(test_images_filepaths)
定义一个PyTorch数据集类
===============
接下来,我们定义一个PyTorch数据集。 如果您不熟悉PyTorch数据集,请参阅本教程-https://pytorch.org/tutorials/beginner/data_loading_tutorial.html。 输出任务是二进制分类-模型需要预测图像包含猫还是狗。 我们的标签将标记图像包含猫的可能性。 因此,带有猫的图像的正确标签将为1.0,带有狗的图像的正确标签将为0.0。 __init__将收到一个可选的转换参数。 它是“白化”增强管道的转换功能。 然后在__getitem__中,Dataset类将使用该函数来扩大图像并返回正确的标签。
class CatsVsDogsDataset(Dataset):
def init(self, images_filepaths, transform=None):
self.images_filepaths = images_filepaths
self.transform = transform
def len(self):
return len(self.images_filepaths)
def getitem(self, idx):
image_filepath = self.images_filepaths[idx]
image = cv2.imread(image_filepath)
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
if os.path.normpath(image_filepath).split(os.sep)[-2] == “Cat”:
label = 1.0
else:
label = 0.0
if self.transform is not None:
image = self.transform(image=image)[“image”]
return image, label
使用Albumentations定义训练和验证数据集的转换函数
===============================
我们使用Albumentation定义用于训练和验证数据集的扩充管道。在这两个管道中,我们首先调整输入图像的大小,因此其最小尺寸为160px,然后进行128px x 128px的裁剪。对于训练数据集,我们还对该作物应用更多的增强。接下来,我们将对图像进行归一化。我们首先将图像的所有像素值除以255,因此每个像素的值将在[0.0,1.0]范围内。然后,我们将减去平均像素值,然后将其除以标准偏差。增强流水线的均值和标准差取自ImageNet数据集。尽管如此,它们仍然可以很好地传输到``猫与狗’'数据集。之后,我们将应用ToTensorV2将Tombs数组转换为PyTorch张量,该张量将用作神经网络的输入。 请注意,在验证管道中,我们将使用A.CenterCrop而不是A.RandomCrop,因为我们希望验证结果具有确定性(这样就不会依赖于作物的随机位置)。
train_transform = A.Compose(
[
A.SmallestMaxSize(max_size=160),
A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
A.RandomCrop(height=128, width=128),
A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
A.RandomBrightnessContrast(p=0.5),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
]
)
train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)
val_transform = A.Compose(
[
A.SmallestMaxSize(max_size=160),
A.CenterCrop(height=128, width=128),
A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
ToTensorV2(),
]
)
val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)
还让我们定义一个函数,该函数采用数据集并可视化应用于同一图像的不同增强。
def visualize_augmentations(dataset, idx=0, samples=10, cols=5):
dataset = copy.deepcopy(dataset)
dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])
rows = samples // cols
figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))
for i in range(samples):
image, _ = dataset[idx]
ax.ravel()[i].imshow(image)
ax.ravel()[i].set_axis_off()
plt.tight_layout()
plt.show()
random.seed(42)
visualize_augmentations(train_dataset)
定义训练辅助方法
========
我们定义了训练的辅助方法。 compute_accuracy接受模型预测和真实标签,并将返回这些预测的准确性。 MetricMonitor有助于跟踪训练和验证过程中的准确性或损失等指标
def calculate_accuracy(output, target):
output = torch.sigmoid(output) >= 0.5
target = target == 1.0
return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()
class MetricMonitor:
def init(self, float_precision=3):
self.float_precision = float_precision