面试突击班PyTorch 和 Albumentations 实现图像分类(猫狗大战,Python面试真题解析火爆全网

本文介绍了如何使用PyTorch和Albumentations库对猫狗图像进行分类。通过定义不同的训练和验证数据集增强策略,如随机作物、亮度对比度调整和归一化,来训练ResNet50模型。在训练过程中,使用BCEWithLogitsLoss损失函数和Adam优化器,展示了训练和验证的准确性和损失。文章还提供了一个辅助函数用于可视化图像增强效果,并给出了完整的训练流程。
摘要由CSDN通过智能技术生成

我们使用Albumentation定义用于训练和验证数据集的扩充管道。在这两个管道中,我们首先调整输入图像的大小,因此其最小尺寸为160px,然后进行128px x 128px的裁剪。对于训练数据集,我们还对该作物应用更多的增强。接下来,我们将对图像进行归一化。我们首先将图像的所有像素值除以255,因此每个像素的值将在[0.0,1.0]范围内。然后,我们将减去平均像素值,然后将其除以标准偏差。增强流水线的均值和标准差取自ImageNet数据集。尽管如此,它们仍然可以很好地传输到``猫与狗’'数据集。之后,我们将应用ToTensorV2将Tombs数组转换为PyTorch张量,该张量将用作神经网络的输入。 请注意,在验证管道中,我们将使用A.CenterCrop而不是A.RandomCrop,因为我们希望验证结果具有确定性(这样就不会依赖于作物的随机位置)。

train_transform = A.Compose(

[

A.SmallestMaxSize(max_size=160),

A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),

A.RandomCrop(height=128, width=128),

A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),

A.RandomBrightnessContrast(p=0.5),

A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

ToTensorV2(),

]

)

train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)

val_transform = A.Compose(

[

A.SmallestMaxSize(max_size=160),

A.CenterCrop(height=128, width=128),

A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

ToTensorV2(),

]

)

val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)

还让我们定义一个函数,该函数采用数据集并可视化应用于同一图像的不同增强。

def visualize_augmentations(dataset, idx=0, samples=10, cols=5):

dataset = copy.deepcopy(dataset)

dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])

rows = samples // cols

figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

for i in range(samples):

image, _ = dataset[idx]

ax.ravel()[i].imshow(image)

ax.ravel()[i].set_axis_off()

plt.tight_layout()

plt.show()

random.seed(42)

visualize_augmentations(train_dataset)

定义训练辅助方法

========

我们定义了训练的辅助方法。 compute_accuracy接受模型预测和真实标签,并将返回这些预测的准确性。 MetricMonitor有助于跟踪训练和验证过程中的准确性或损失等指标

def calculate_accuracy(output, target):

output = torch.sigmoid(output) >= 0.5

target = target == 1.0

return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()

class MetricMonitor:

def init(self, float_precision=3):

self.float_precision = float_precision

self.reset()

def reset(self):

self.metrics = defaultdict(lambda: {“val”: 0, “count”: 0, “avg”: 0})

def update

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值