最近机器学习的课程中接触到了图像检索任务,也顺便了解到了Triplet Loss,网上有很多关于理论的介绍,这里就不多赘述了,本篇着重关注当我们应用到Triplet形式的数据时,如何对数据进行封装,训练。
本篇以CIFAR-10的图像检索实验为例,使用Triplet Loss微调ResNet18。(由于对这个领域接触较少,有错误的地方烦请大佬们指正)。
废话不多说,见代码
import torch
from torch.nn.functional import pairwise_distance
import torchvision
import torchvision.transforms as transforms
from torchvision.models import resnet18
import matplotlib.pyplot as plt
import torch.nn as nn
from sklearn.neighbors import NearestNeighbors
import warnings
import random
warnings.filterwarnings('ignore')
# Normalize with ImageNet statistics so the pretrained ResNet-18 sees inputs
# on the distribution it was trained on (mean/std per RGB channel).
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
# Demo only: the test split is not used here.
trainset = torchvision.datasets.CIFAR10(root='./data',
                                        download=False, transform=transform)
# shuffle=False keeps sample order stable so feature/label rows line up
# with dataset indices during the retrieval step below.
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
                                          shuffle=False, num_workers=2)
class Triplet_dataset(torch.utils.data.Dataset):
    """Dataset wrapper that serves (anchor, positive, negative) triplets.

    Accepts any indexable dataset of (image, label) pairs (e.g. the CIFAR-10
    trainset). For each index, the anchor is that sample, the positive is a
    random same-label sample, and the negative is a random different-label
    sample.
    """

    def __init__(self, data):
        # Cache every image and label tensor up front for fast sampling.
        self.data = data
        self.image = [torch.as_tensor(data[i][0]) for i in range(len(self.data))]
        self.label = [torch.tensor(data[i][1]) for i in range(len(self.data))]
        # BUGFIX: the original did `self.image[self.label == label]`, which
        # compares a Python list to a tensor and then indexes a list with the
        # result -- a runtime TypeError. Precompute label -> indices instead.
        self._indices_by_label = {}
        for idx, lab in enumerate(self.label):
            self._indices_by_label.setdefault(int(lab), []).append(idx)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, item):
        label = int(self.label[item])
        anchor = self.image[item]
        # Positive: any sample sharing the anchor's label (may be the anchor itself,
        # matching the original's intent of sampling from the same-label pool).
        positive = self.image[random.choice(self._indices_by_label[label])]
        # Negative: rejection-sample any index whose label differs.
        neg_idx = random.randrange(len(self.data))
        while int(self.label[neg_idx]) == label:
            neg_idx = random.randrange(len(self.data))
        negative = self.image[neg_idx]
        return anchor, positive, negative
# CIFAR-10 class names, ordered by their integer label (0..9).
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer',
               'dog', 'frog', 'horse', 'ship', 'truck']
以下为观察数据集
import numpy as np
def imshow(img):
    """Un-normalize a (3, H, W) image tensor and display it with matplotlib.

    BUGFIX: the original used `img / 2 + 0.5`, which inverts a
    Normalize((0.5,...), (0.5,...)) transform -- but this script normalizes
    with ImageNet statistics, so colors came out wrong. Invert the actual
    transform (x * std + mean) and clip to [0, 1] for display.
    """
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    img = img * std + mean
    npimg = img.numpy()
    # Matplotlib expects channels last: (H, W, 3).
    plt.imshow(np.transpose(npimg, (1, 2, 0)).clip(0, 1))
    plt.show()
# Grab one batch of training images and visualize them as a grid.
dataiter = iter(trainloader)
# BUGFIX: `dataiter.next()` is not valid -- the `.next()` method does not
# exist on Python-3 iterators / current DataLoader iterators. Use builtin next().
images, labels = next(dataiter)
imshow(torchvision.utils.make_grid(images))
class Net(nn.Module):
    """ResNet-18 feature extractor: the pretrained backbone minus its FC head.

    forward() returns a (batch, 512) embedding per image.
    """

    def __init__(self):
        super(Net, self).__init__()
        backbone = resnet18(pretrained=True)
        # Keep everything up to and including the global average pool;
        # drop only the final classification layer.
        feature_layers = list(backbone.children())[:-1]
        self.resnet = nn.Sequential(*feature_layers)

    def forward(self, x):
        features = self.resnet(x)
        # Collapse the trailing 1x1 spatial dims into a flat embedding.
        return torch.flatten(features, 1)
# Move the embedding network to the GPU.
model = Net().cuda()
"""这里就是稍微重要的处理步骤了,自己定义一个生成器函数,封装Triplet_dataset"""
# NOTE: the bare string above is a no-op statement kept from the original post;
# it announces the generator-based batching of Triplet_dataset below.
triplet_data = Triplet_dataset(trainset)
batch_size = 32
def loader_data():
    """Yield mini-batches of (anchor, positive, negative) stacked tensors.

    Walks `triplet_data` in order, emitting `batch_size` triplets per yield.
    FIX: the original hard-coded 32 and an empty (32, 3, 32, 32) container,
    silently breaking for any other `batch_size` or image size; stacking the
    actual samples works for any shape.
    NOTE(review): like the original, the count `len // batch_size - 1`
    drops the last batch; kept for compatibility with train_model.
    """
    for k in range(len(triplet_data) // batch_size - 1):
        triplets = [triplet_data[i]
                    for i in range(k * batch_size, (k + 1) * batch_size)]
        anchor = torch.stack([t[0] for t in triplets])
        positive = torch.stack([t[1] for t in triplets])
        negative = torch.stack([t[2] for t in triplets])
        yield anchor, positive, negative
class TripletLoss(nn.Module):
    """Margin-based triplet loss over Euclidean (L2) pairwise distances.

    loss = mean(max(0, d(a, p) - d(a, n) + margin))
    NOTE(review): the original comment said the metric was cosine, but
    torch.nn.functional.pairwise_distance is an L2 (p=2) distance.
    """

    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        d_ap = pairwise_distance(anchor, positive)
        d_an = pairwise_distance(anchor, negative)
        # Hinge: zero loss once the negative is `margin` farther than the positive.
        hinge = torch.clamp(d_ap - d_an + self.margin, min=0.0)
        return hinge.mean()
# Initialize the loss function and the optimizer.
criterion = TripletLoss(margin=1.0)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
def train_model(model, criterion, optimizer, num_epochs=25):
    """Fine-tune the embedding network with triplet loss.

    Consumes one full pass of loader_data() per epoch (it yields
    len(triplet_data) // batch_size - 1 batches) and logs the average
    running loss every 10 batches.
    """
    model.train()
    n_batches = len(triplet_data) // batch_size - 1
    for epoch in range(num_epochs):
        running_loss = 0.0
        batch_stream = loader_data()
        for i in range(n_batches):
            triplet = next(batch_stream)
            anchor, positive, negative = (t.cuda() for t in triplet)
            optimizer.zero_grad()
            out_a = model(anchor)
            out_p = model(positive)
            out_n = model(negative)
            loss = criterion(out_a, out_p, out_n)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            # Report the 10-batch moving average.
            if i % 10 == 9:
                print(f"[Epoch {epoch+1}, Batch {i+1}] loss: {running_loss / 10:.3f}")
                running_loss = 0.0
train_model(model, criterion, optimizer, num_epochs=1)
可以看到,模型逐渐趋于收敛。
# Switch to inference mode (freezes BatchNorm running stats, disables dropout).
model.eval()
"""进行特征的提取,并处理数据"""
# NOTE: the bare string above is a no-op kept from the original post;
# it introduces the feature-extraction step below.
def get_data(loader):
    """Embed every batch in `loader` with the module-level `model`.

    Returns (features, labels): features is an (N, D) float array of
    embeddings, labels an (N,) int array, both in loader order.
    Relies on the global `model` being in eval mode and on CUDA.
    """
    features, labels = [], []
    # Single no_grad context around the whole pass; also avoids shadowing
    # the builtin `input` as the original did.
    with torch.no_grad():
        for images, batch_labels in loader:
            output = model(images.cuda())
            # BUGFIX: the original called .squeeze(), which would drop the
            # batch axis when the final batch holds exactly one sample and
            # then break np.concatenate. Flatten per-sample instead.
            output = output.reshape(output.shape[0], -1).cpu().numpy()
            features.append(output)
            labels.append(batch_labels.reshape(-1).cpu().numpy())
    features = np.concatenate(features, axis=0)
    labels = np.concatenate(labels)
    return features, labels
train_feature, train_label = get_data(trainloader)
#%%
# Fit a nearest-neighbour index over the embeddings using cosine distance.
# NOTE(review): NearestNeighbors is unsupervised -- fit() ignores its second
# argument, so passing train_label here is harmless but has no effect.
nearest_model= NearestNeighbors(n_neighbors=8, metric='cosine')
nearest_model.fit(train_feature, train_label)
def retrieve_top_k(query_feature):
    """Return the neighbour indices for a single query.

    `query_feature` must be a 2-D array-like holding one feature row;
    the result is the index row for that single query.
    """
    _, neighbour_idx = nearest_model.kneighbors(query_feature)
    return neighbour_idx[0]
# Use training image #100 as the query (demo only -- the original comment
# wrongly called it a test image).
query_idx = 100
query_feature = train_feature[query_idx]
top_k_indices = retrieve_top_k([query_feature])
# Print the retrieval results. FIX: the index was fit with n_neighbors=8,
# so the original "Top-10" labels were misleading.
print(f"Query Image Label: {train_label[query_idx]}")
print(f"Top-8 Retrieved Indices: {top_k_indices}")
print(f"Top-8 Retrieved Labels: {train_label[top_k_indices]}")
# Show the query image. FIX: invert the actual ImageNet normalization
# (x * std + mean) instead of the original `/ 2 + 0.5`, which assumed
# Normalize((0.5,...), (0.5,...)) and distorted the colors.
query_img = trainset[query_idx][0]
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
img = query_img.numpy() * std + mean
plt.xlabel(class_names[train_label[query_idx]])
plt.imshow(np.transpose(img, (1, 2, 0)).clip(0, 1))
# Plot the retrieved neighbours in an n_row x n_col grid.
n_row, n_col = 2, 4
plt.figure(figsize=(8, 8))
mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
for i in range(1, n_row * n_col + 1):
    # FIX: use n_row/n_col here instead of hard-coded subplot(2, 4, ...),
    # so the grid shape stays consistent if n_row/n_col change.
    plt.subplot(n_row, n_col, i)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    idx = top_k_indices[i - 1]
    plt.xlabel(class_names[train_label[idx]])
    # FIX: invert the actual ImageNet normalization, not `/ 2 + 0.5`.
    npimg = trainset[idx][0].numpy() * std + mean
    plt.imshow(np.transpose(npimg, (1, 2, 0)).clip(0, 1))
博主对比了应用Triplet Loss微调ResNet的方法与直接调用预训练ResNet的方法,微调后的模型训练出来的特征更容易寻找到相关图片。