PyTorch 和 Albumentations 实现图像分类(猫狗大战)

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

true_label = os.path.normpath(image_filepath).split(os.sep)[-2]

predicted_label = predicted_labels[i] if predicted_labels else true_label

color = “green” if true_label == predicted_label else “red”

ax.ravel()[i].imshow(image)

ax.ravel()[i].set_title(predicted_label, color=color)

ax.ravel()[i].set_axis_off()

plt.tight_layout()

plt.show()

display_image_grid(test_images_filepaths)

定义一个PyTorch数据集类

===============

接下来,我们定义一个PyTorch数据集。 如果您不熟悉PyTorch数据集,请参阅本教程-https://pytorch.org/tutorials/beginner/data_loading_tutorial.html。 输出任务是二进制分类-模型需要预测图像包含猫还是狗。 我们的标签将标记图像包含猫的可能性。 因此,带有猫的图像的正确标签将为1.0,带有狗的图像的正确标签将为0.0。 __init__将收到一个可选的转换参数。 它是“白化”增强管道的转换功能。 然后在__getitem__中,Dataset类将使用该函数来扩大图像并返回正确的标签。

class CatsVsDogsDataset(Dataset):

def init(self, images_filepaths, transform=None):

self.images_filepaths = images_filepaths

self.transform = transform

def len(self):

return len(self.images_filepaths)

def getitem(self, idx):

image_filepath = self.images_filepaths[idx]

image = cv2.imread(image_filepath)

image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

if os.path.normpath(image_filepath).split(os.sep)[-2] == “Cat”:

label = 1.0

else:

label = 0.0

if self.transform is not None:

image = self.transform(image=image)[“image”]

return image, label

使用Albumentations定义训练和验证数据集的转换函数

===============================

我们使用Albumentation定义用于训练和验证数据集的扩充管道。在这两个管道中,我们首先调整输入图像的大小,因此其最小尺寸为160px,然后进行128px x 128px的裁剪。对于训练数据集,我们还对该作物应用更多的增强。接下来,我们将对图像进行归一化。我们首先将图像的所有像素值除以255,因此每个像素的值将在[0.0,1.0]范围内。然后,我们将减去平均像素值,然后将其除以标准偏差。增强流水线的均值和标准差取自ImageNet数据集。尽管如此,它们仍然可以很好地传输到``猫与狗’'数据集。之后,我们将应用ToTensorV2将Tombs数组转换为PyTorch张量,该张量将用作神经网络的输入。 请注意,在验证管道中,我们将使用A.CenterCrop而不是A.RandomCrop,因为我们希望验证结果具有确定性(这样就不会依赖于作物的随机位置)。

train_transform = A.Compose(

[

A.SmallestMaxSize(max_size=160),

A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),

A.RandomCrop(height=128, width=128),

A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),

A.RandomBrightnessContrast(p=0.5),

A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

ToTensorV2(),

]

)

train_dataset = CatsVsDogsDataset(images_filepaths=train_images_filepaths, transform=train_transform)

val_transform = A.Compose(

[

A.SmallestMaxSize(max_size=160),

A.CenterCrop(height=128, width=128),

A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),

ToTensorV2(),

]

)

val_dataset = CatsVsDogsDataset(images_filepaths=val_images_filepaths, transform=val_transform)

还让我们定义一个函数,该函数采用数据集并可视化应用于同一图像的不同增强。

def visualize_augmentations(dataset, idx=0, samples=10, cols=5):

dataset = copy.deepcopy(dataset)

dataset.transform = A.Compose([t for t in dataset.transform if not isinstance(t, (A.Normalize, ToTensorV2))])

rows = samples // cols

figure, ax = plt.subplots(nrows=rows, ncols=cols, figsize=(12, 6))

for i in range(samples):

image, _ = dataset[idx]

ax.ravel()[i].imshow(image)

ax.ravel()[i].set_axis_off()

plt.tight_layout()

plt.show()

random.seed(42)

visualize_augmentations(train_dataset)

定义训练辅助方法

========

我们定义了训练的辅助方法。 compute_accuracy接受模型预测和真实标签,并将返回这些预测的准确性。 MetricMonitor有助于跟踪训练和验证过程中的准确性或损失等指标

def calculate_accuracy(output, target):

output = torch.sigmoid(output) >= 0.5

target = target == 1.0

return torch.true_divide((target == output).sum(dim=0), output.size(0)).item()

class MetricMonitor:

def init(self, float_precision=3):

self.float_precision = float_precision

  • 22
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值