Triplet Loss in PyTorch (with the full training pipeline)

I recently ran into image retrieval in a machine learning course, and with it the Triplet Loss. There are plenty of articles covering the theory, so I won't repeat it here; this post focuses on the practical side: how to wrap the data into triplets and train on them.

As a running example, this post fine-tunes ResNet18 with a Triplet Loss on a CIFAR-10 image retrieval task. (I am fairly new to this area, so corrections are welcome.)

Without further ado, the code:

import torch
from torch.nn.functional import pairwise_distance
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.neighbors import NearestNeighbors
import warnings
import random
warnings.filterwarnings('ignore')

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])

trainset = torchvision.datasets.CIFAR10(root='./data',
                                        download=False, transform=transform)   # set download=True on the first run; for demonstration only the training split is used
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=False, num_workers=2)

class Triplet_dataset(torch.utils.data.Dataset):   # a custom Dataset that returns (anchor, positive, negative) triplets
    def __init__(self, data):
        # Takes trainset as input and materialises all images and labels as tensors.
        # With your own data you can of course skip the trainset step entirely.
        self.data = data
        self.image = torch.stack([data[i][0] for i in range(len(self.data))])
        self.label = torch.tensor([data[i][1] for i in range(len(self.data))])
    def __len__(self):
        return len(self.data)
    def __getitem__(self, item):
        label = self.label[item]
        anchor = self.image[item]  # anchor: the indexed image; positive: a random same-class image; negative: a random image from a different class
        positive = random.choice(self.image[self.label == label])
        negative = random.choice(self.image[self.label != label])
        return anchor, positive, negative
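Note that the boolean-mask lookups above copy every same-class (or different-class) image on each __getitem__ call. That is fine for CIFAR-10 held in memory, but if it ever becomes a bottleneck, one option is to precompute an index list per class and sample indices instead. A minimal sketch (the class name TripletDatasetFast is hypothetical, not from this post):

class TripletDatasetFast(torch.utils.data.Dataset):
    """Same triplet sampling as above, but via precomputed per-class index lists."""
    def __init__(self, data):
        self.images = torch.stack([data[i][0] for i in range(len(data))])
        self.labels = torch.tensor([data[i][1] for i in range(len(data))])
        # map each class id to the indices of the samples belonging to it
        self.class_to_indices = {int(c): torch.nonzero(self.labels == c).flatten().tolist()
                                 for c in self.labels.unique()}
    def __len__(self):
        return len(self.labels)
    def __getitem__(self, item):
        label = int(self.labels[item])
        pos_idx = random.choice(self.class_to_indices[label])
        # pick a different class, then a random sample from it
        neg_class = random.choice([c for c in self.class_to_indices if c != label])
        neg_idx = random.choice(self.class_to_indices[neg_class])
        return self.images[item], self.images[pos_idx], self.images[neg_idx]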

class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']

Let's take a quick look at the dataset:

import numpy as np
# Undo the ImageNet normalization used above so images display with their natural colours
_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
def unnormalize(img):
    return (img * _std + _mean).clamp(0, 1)
def imshow(img):
    npimg = unnormalize(img).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()
# get some random training images
dataiter = iter(trainloader)
images, labels = next(dataiter)   # `.next()` was removed in newer PyTorch versions
imshow(torchvision.utils.make_grid(images))

 

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        resnet = resnet18(pretrained=True)   # on newer torchvision, pass weights=ResNet18_Weights.IMAGENET1K_V1 instead
        self.resnet = nn.Sequential(*list(resnet.children())[:-1])  # pretrained backbone with the final fc layer removed
    def forward(self, x):
        x = self.resnet(x)
        x = torch.flatten(x, 1)
        return x
model = Net().cuda()
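Since only ResNet18's final fully connected layer was dropped, each image is mapped to a 512-dimensional embedding. A quick sanity check on one batch (just an illustration, assuming a CUDA device as in the rest of the post):

# one batch in, (batch_size, 512) embeddings out
with torch.no_grad():
    sample_images, _ = next(iter(trainloader))
    print(model(sample_images.cuda()).shape)   # expected: torch.Size([32, 512])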
"""这里就是稍微重要的处理步骤了,自己定义一个生成器函数,封装Triplet_dataset"""
triplet_data = Triplet_dataset(trainset)
batch_size = 32
def loader_data():
    for k in range(len(triplet_data) // batch_size):
        # empty containers for one batch of anchors / positives / negatives
        anchor = torch.empty(size=(batch_size, 3, 32, 32))
        positive = torch.empty(size=(batch_size, 3, 32, 32))
        negative = torch.empty(size=(batch_size, 3, 32, 32))
        for i in range(k * batch_size, (k + 1) * batch_size):
            anchor[i - batch_size * k], positive[i - batch_size * k], negative[i - batch_size * k] = triplet_data[i]
        yield anchor, positive, negative
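An alternative worth knowing: because Triplet_dataset.__getitem__ already returns three tensors, PyTorch's standard DataLoader can batch it directly with the default collate function. The hand-written generator above is kept for clarity; the sketch below (the name triplet_loader is just for illustration) would work as a drop-in replacement:

# sketch: let DataLoader handle batching and shuffling of the triplets
triplet_loader = torch.utils.data.DataLoader(triplet_data, batch_size=batch_size, shuffle=True)
# for anchor, positive, negative in triplet_loader:
#     ...  # identical usage to the generator above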
class TripletLoss(nn.Module):   # define the Triplet Loss; the distance metric here is the Euclidean (L2) pairwise distance
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_dist = pairwise_distance(anchor, positive)
        neg_dist = pairwise_distance(anchor, negative)
        loss = torch.relu(pos_dist - neg_dist + self.margin)
        return loss.mean()
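For reference, this hand-written loss is the same computation as PyTorch's built-in nn.TripletMarginLoss with Euclidean distance; spelling it out just makes the margin logic explicit. The builtin_criterion below is only for comparison and is not used later:

# equivalent built-in loss: Euclidean (p=2) triplet margin loss with margin 1.0
builtin_criterion = nn.TripletMarginLoss(margin=1.0, p=2)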
# initialize the loss function and optimizer
criterion = TripletLoss(margin=1.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
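Fine-tuning all layers with AdamW, as above, is the simplest choice. If GPU memory or overfitting becomes an issue, a common variant (not used in the original run, shown here only as a sketch) is to freeze the early ResNet blocks and update only the later ones:

# optional variant: freeze the stem, layer1 and layer2 (children 0-5 of the backbone)
for p in model.resnet[:6].parameters():
    p.requires_grad = False
optimizer = torch.optim.AdamW((p for p in model.parameters() if p.requires_grad), lr=0.001)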

def train_model(model, criterion, optimizer, num_epochs=25):
    model.train()
    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (anchor, positive, negative) in enumerate(loader_data()):
            anchor, positive, negative = anchor.cuda(), positive.cuda(), negative.cuda()
            optimizer.zero_grad()
            anchor_out = model(anchor)
            positive_out = model(positive)
            negative_out = model(negative)
            loss = criterion(anchor_out, positive_out, negative_out)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            if i % 10 == 9:
                print(f"[Epoch {epoch+1}, Batch {i+1}] loss: {running_loss / 10:.3f}")
                running_loss = 0.0
train_model(model, criterion, optimizer, num_epochs=1)
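If you want to reuse the fine-tuned embedding network later (for example, to index a larger gallery without retraining), it is worth saving the weights right after training; the filename below is just a placeholder:

# save the fine-tuned backbone weights; the path is an arbitrary example
torch.save(model.state_dict(), 'triplet_resnet18.pth')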

As the printed loss values show, the model gradually converges.

model.eval()
"""进行特征的提取,并处理数据"""
def get_data(loader):
    features, labels = [], []
    for batch_idx, (image, label) in enumerate(loader):
        input = image.cuda()
        with torch.no_grad():
            output = model(input).squeeze().detach().cpu().numpy()
        features.append(output)
        labels.append(label.reshape(-1).detach().numpy())
    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels)
    return features, labels
train_feature, train_label = get_data(trainloader)
nearest_model = NearestNeighbors(n_neighbors=8, metric='cosine')  # fit a nearest-neighbour index using cosine distance
nearest_model.fit(train_feature)   # NearestNeighbors is unsupervised, so only the features are needed
def retrieve_top_k(query_feature):
    distances, indices = nearest_model.kneighbors(query_feature)
    return indices[0]
query_idx = 100  # use training image #100 as the query
query_feature = train_feature[query_idx]
top_k_indices = retrieve_top_k([query_feature])
# print the retrieval results
print(f"Query Image Label: {train_label[query_idx]}")
print(f"Top-8 Retrieved Indices: {top_k_indices}")
print(f"Top-8 Retrieved Labels: {train_label[top_k_indices]}")
query_img = unnormalize(trainset[query_idx][0]).numpy()
plt.xlabel(class_names[train_label[query_idx]])
plt.imshow(np.transpose(query_img, (1, 2, 0)))

n_row, n_col = 2, 4
plt.figure(figsize=(8, 8))
for i in range(1, n_row * n_col + 1):
    plt.subplot(n_row, n_col, i)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.xlabel(class_names[train_label[top_k_indices[i - 1]]])
    npimg = unnormalize(trainset[top_k_indices[i - 1]][0]).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
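A single query only gives a visual impression. As a rough quantitative check (a small sketch of my own, not part of the original experiment), we can sample a few hundred training images as queries and measure precision@8, i.e. the fraction of retrieved neighbours that share the query's class:

# sketch: average precision@8 over 200 randomly chosen query images
# note: each query's own feature is among its neighbours, which slightly inflates the score
query_indices = random.sample(range(len(train_feature)), 200)
_, neigh_indices = nearest_model.kneighbors(train_feature[query_indices])
precision_at_8 = (train_label[neigh_indices] == train_label[query_indices][:, None]).mean()
print(f"precision@8 over 200 queries: {precision_at_8:.3f}")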

I also compared the Triplet-Loss-fine-tuned ResNet against the plain pretrained ResNet used directly as a feature extractor: the fine-tuned features retrieve relevant images noticeably more reliably.
