import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
class VGGBase(nn.Module):
    """Convolutional backbone taken from torchvision's pretrained VGG-16.

    The last two modules of ``vgg16().features`` are dropped; the remaining
    layers are applied in order as a feature extractor.
    """

    def __init__(self):
        super().__init__()
        backbone = torchvision.models.vgg16(pretrained=True)
        # Keep every feature layer except the final two.
        kept_layers = list(backbone.features.children())[:-2]
        self.features = nn.Sequential(*kept_layers)

    def forward(self, x):
        """Return the feature map produced by the truncated VGG-16 stack."""
        return self.features(x)
class PredictionLayers(nn.Module):
    """Convolutional heads that predict box offsets and class scores.

    Four head pairs are created. They expect feature maps with 512, 1024,
    512 and 256 channels and predict 4, 6, 6 and 6 boxes per spatial
    location respectively.

    Args:
        num_classes: Number of class scores predicted for every box.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.loc_layers = nn.ModuleList()
        self.conf_layers = nn.ModuleList()

        # (input channels, boxes per spatial location) for each head pair.
        head_specs = ((512, 4), (1024, 6), (512, 6), (256, 6))
        for in_channels, boxes_per_cell in head_specs:
            # For each feature map the localisation head is built first,
            # then the classification head.
            self.loc_layers.append(
                nn.Conv2d(in_channels, boxes_per_cell * 4,
                          kernel_size=3, padding=1)
            )
            self.conf_layers.append(
                nn.Conv2d(in_channels, boxes_per_cell * num_classes,
                          kernel_size=3, padding=1)
            )

    def forward(self, features):
        """Apply the heads to ``features`` and concatenate the predictions.

        Args:
            features: Iterable of feature maps, paired in order with the
                head pairs; ``zip`` stops at the shorter sequence, so extra
                heads (or extra feature maps) are silently unused.

        Returns:
            ``(loc_preds, conf_preds)`` reshaped to ``(batch, -1, 4)`` and
            ``(batch, -1, num_classes)`` and concatenated along dim 1.
        """
        loc_chunks = []
        conf_chunks = []
        heads = zip(features, self.loc_layers, self.conf_layers)
        for feature_map, loc_head, conf_head in heads:
            batch = feature_map.size(0)
            # Move channels last so each box's values are contiguous
            # before flattening.
            loc_map = loc_head(feature_map).permute(0, 2, 3, 1).contiguous()
            loc_chunks.append(loc_map.view(batch, -1, 4))
            conf_map = conf_head(feature_map).permute(0, 2, 3, 1).contiguous()
            conf_chunks.append(conf_map.view(batch, -1, self.num_classes))
        return torch.cat(loc_chunks, 1), torch.cat(conf_chunks, 1)
'''
loc_preds = [
[
[[0.1, 0.2, 0.3, 0.4], [0.2, 0.3, 0.4, 0.5], [0.3, 0.4, 0.5, 0.6]],
[[0.4, 0.5, 0.6, 0.7], [0.5, 0.6, 0.7, 0.8], [0.6, 0.7, 0.8, 0.9]]
], this is the box-prediction result of the first layer
[
[[0.2, 0.3, 0.4, 0.5], [0.3, 0.4, 0.5, 0.6], [0.4, 0.5, 0.6, 0.7]],
[[0.5, 0.6, 0.7, 0.8], [0.6, 0.7, 0.8, 0.9], [0.7, 0.8, 0.9, 1.0]]
]
]
conf_preds = [
[
[[0.9, 0.1], [0.7, 0.3], [0.6, 0.4]],
[[0.8, 0.2], [0.5, 0.5], [0.4, 0.6]]
], this is the class-prediction result of the first layer, using two-class classification as an example
[
[[0.7, 0.3], [0.6, 0.4], [0.5, 0.5]],
[[0.6, 0.4], [0.4, 0.6], [0.3, 0.7]]
]
]
'''
class SSD(nn.Module):
    """Detector made of a ``VGGBase`` backbone and ``PredictionLayers`` heads.

    Args:
        num_classes: Number of class scores predicted per box.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes
        self.base = VGGBase()
        self.pred = PredictionLayers(num_classes)

    def forward(self, x):
        """Return ``(loc_preds, conf_preds)`` for the input batch ``x``."""
        feature_map = self.base(x)
        # The backbone output is passed as a single-element list, so the
        # prediction module receives exactly one feature map.
        return self.pred([feature_map])
def decode_predictions(loc_preds, conf_preds, threshold=0.5):
    """Keep the predictions of the first batch element that pass a threshold.

    Class scores are softmax-normalised over the last dimension. For every
    prior the best class and its probability are taken, and the prior is
    kept when that probability is strictly greater than ``threshold``.

    Only batch index 0 is decoded. Box values are returned exactly as
    stored in ``loc_preds``; no prior-box decoding or NMS is applied here.

    Args:
        loc_preds: Tensor indexed as ``[batch, prior, 4]``.
        conf_preds: Tensor of raw class scores indexed as
            ``[batch, prior, class]``.
        threshold: Minimum (exclusive) softmax probability to keep a prior.

    Returns:
        ``(boxes, labels, scores)``: a list of 1-D numpy arrays, a list of
        int class indices and a list of float probabilities, in prior order.
    """
    probs = torch.softmax(conf_preds, dim=-1)
    # Vectorised over all priors instead of a Python loop with per-element
    # tensor -> Python conversions.
    best_scores, best_labels = probs[0].max(dim=-1)
    keep = best_scores > threshold

    boxes = list(loc_preds[0][keep].cpu().numpy())
    labels = best_labels[keep].tolist()
    scores = best_scores[keep].tolist()
    return boxes, labels, scores
def to_original_scale(box, height, width):
    """Scale a ``(ymin, xmin, ymax, xmax)`` box to integer pixel coordinates.

    The y values are multiplied by ``height`` and the x values by ``width``;
    results are truncated with ``int()``.

    The scaling is done without in-place operators so the caller's ``box``
    is never modified. (Unpacking a torch tensor yields views, so ``*=`` on
    the unpacked values would write back into the original tensor.)

    Args:
        box: Sequence of four numbers ordered ``ymin, xmin, ymax, xmax``.
        height: Vertical scale factor (image height in pixels).
        width: Horizontal scale factor (image width in pixels).

    Returns:
        Tuple ``(xmin, ymin, xmax, ymax)`` of Python ints. Note the x-first
        order, which differs from the input order.
    """
    ymin, xmin, ymax, xmax = box
    return (
        int(xmin * width),
        int(ymin * height),
        int(xmax * width),
        int(ymax * height),
    )
def draw_boxes(img, boxes, labels, scores):
    """Draw every detection on ``img`` and return the same image object.

    Each box is converted to pixel coordinates with ``to_original_scale``
    using the image's own height and width, outlined in green, and
    captioned with ``"<label>: <score>"`` ten pixels above its top-left
    corner.

    Args:
        img: Image array drawn on directly by the OpenCV calls.
        boxes: Iterable of boxes accepted by ``to_original_scale``.
        labels: Iterable of labels, paired with ``boxes`` in order.
        scores: Iterable of scores, paired with ``boxes`` in order.

    Returns:
        ``img``, after drawing.
    """
    green = (0, 255, 0)
    for detection in zip(boxes, labels, scores):
        box, label, score = detection
        left, top, right, bottom = to_original_scale(
            box, img.shape[0], img.shape[1]
        )
        cv2.rectangle(img, (left, top), (right, bottom), green, 2)
        caption = f'{label}: {score:.2f}'
        cv2.putText(img, caption, (left, top - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, green, 2)
    return img
if __name__ == "__main__":
    # Demo: run the detector on one image and display the drawn detections.
    num_classes = 21
    model = SSD(num_classes)

    image_path = 'example.jpg'
    img = cv2.imread(image_path)
    # cv2.imread signals failure by returning None rather than raising.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    height, width, _ = img.shape

    transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((300, 300)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # OpenCV loads images in BGR order, while ToPILImage and the
    # normalisation statistics above assume RGB, so convert first.
    rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    input_img = transform(rgb_img).unsqueeze(0)

    model.eval()
    with torch.no_grad():
        loc_preds, conf_preds = model(input_img)

    boxes, labels, scores = decode_predictions(loc_preds, conf_preds)
    # Draw on the original BGR image, which is what cv2.imshow expects.
    result_img = draw_boxes(img, boxes, labels, scores)
    cv2.imshow('Detected Objects', result_img)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
# Loss function
import torch
import torch.nn as nn
import torch.nn.functional as F
class SSDDetectionLoss(nn.Module):
    """Detection loss: smooth-L1 localisation plus mined cross-entropy.

    Label convention for ``conf_targets`` (integer tensor, one per prior):

    * ``>= 0`` -- positive prior; the value is its class index.
    * ``-1``   -- negative prior; it is scored against ``background_class``
      and takes part in hard negative mining.
    * ``< -1`` -- ignored; contributes nothing to the classification loss.

    Args:
        num_classes: Size of the class dimension of ``conf_preds``.
        alpha: Weight applied to the localisation loss.
        neg_pos_ratio: Maximum number of negatives kept per positive,
            counted separately for every batch element.
        background_class: Class index used as the cross-entropy target of
            negative priors.
    """

    def __init__(self, num_classes, alpha=1.0, neg_pos_ratio=3,
                 background_class=0):
        super().__init__()
        self.num_classes = num_classes
        self.alpha = alpha
        self.neg_pos_ratio = neg_pos_ratio
        self.background_class = background_class

    def smooth_l1_loss(self, preds, targets, beta=1.0):
        """Return the mean smooth-L1 (Huber-style) loss over all elements.

        Elements with ``|preds - targets| < beta`` use the quadratic branch
        ``0.5 * d**2 / beta``; the others use ``d - 0.5 * beta``.
        """
        diff = torch.abs(preds - targets)
        per_element = torch.where(
            diff < beta, 0.5 * diff ** 2 / beta, diff - 0.5 * beta
        )
        return per_element.mean()

    def cross_entropy_loss(self, preds, targets):
        """Return the classification loss with hard negative mining.

        Args:
            preds: Raw class scores; reshaped to ``(-1, num_classes)``.
            targets: Integer labels of shape ``(batch, priors)`` following
                the class-level label convention.

        Returns:
            Scalar tensor: summed cross-entropy of all positives and of the
            hardest negatives (at most ``neg_pos_ratio`` per positive, per
            batch element), divided by the number of positives (at least 1).
        """
        batch_size = targets.size(0)
        targets = targets.reshape(batch_size, -1)
        pos_mask = targets.ge(0)
        neg_mask = targets.eq(-1)

        # cross_entropy rejects negative class indices, so every
        # non-positive prior is given the background class as its target.
        ce_targets = targets.masked_fill(~pos_mask, self.background_class)
        per_prior = F.cross_entropy(
            preds.reshape(-1, self.num_classes),
            ce_targets.reshape(-1),
            reduction='none',
        ).view(batch_size, -1)

        # Rank only the negatives by loss; everything else is pushed to the
        # end of the ordering with -inf. Mining itself needs no gradient.
        num_pos = pos_mask.sum(dim=1, keepdim=True)
        mining_loss = per_prior.detach().masked_fill(~neg_mask, float('-inf'))
        order = mining_loss.argsort(dim=1, descending=True)
        rank = order.argsort(dim=1)
        hard_neg_mask = (rank < self.neg_pos_ratio * num_pos) & neg_mask

        selected = pos_mask | hard_neg_mask
        return per_prior[selected].sum() / num_pos.sum().clamp(min=1)

    def forward(self, loc_preds, loc_targets, conf_preds, conf_targets):
        """Return ``alpha * localisation_loss + classification_loss``."""
        # NOTE(review): the localisation term is averaged over every prior,
        # including negatives. SSD usually restricts it to positive priors;
        # confirm that loc_targets is meaningful for unmatched priors.
        loc_loss = self.smooth_l1_loss(loc_preds, loc_targets)
        conf_loss = self.cross_entropy_loss(conf_preds, conf_targets)
        return self.alpha * loc_loss + conf_loss
if __name__ == "__main__":
    # Example call of the loss. Guarded so that importing this module does
    # not run it: loc_preds and conf_preds only exist after the demo above.
    num_classes = 21
    loss_fn = SSDDetectionLoss(num_classes)

    # Placeholder targets so the example is runnable. In real training they
    # come from matching ground-truth boxes to the priors. -1 is the value
    # SSDDetectionLoss treats as a negative prior; one prior is marked as a
    # positive of class 1 so that the result is not trivially empty.
    loc_targets = torch.zeros_like(loc_preds)
    conf_targets = torch.full(conf_preds.shape[:2], -1, dtype=torch.long)
    conf_targets[:, 0] = 1

    loss = loss_fn(loc_preds, loc_targets, conf_preds, conf_targets)