import torch
import torch.nn as nn
from utils import intersection_over_union
class YoloLoss(nn.Module):
    """
    Calculate the loss for the YOLO (v1) model.

    S is the split size of the image (7 in the paper),
    B is the number of boxes per cell (2 in the paper),
    C is the number of classes (20 in the paper and the VOC dataset).
    """

    def __init__(self, S=7, B=2, C=20):
        super(YoloLoss, self).__init__()
        # Sum (not mean) reduction matches the paper's loss formulation.
        self.mse = nn.MSELoss(reduction="sum")
        self.S = S
        self.B = B
        self.C = C
        # Weights from the YOLO paper: how much to penalize confidence in
        # cells with no object (noobj) and the box coordinate terms (coord).
        self.lambda_noobj = 0.5
        self.lambda_coord = 5

    def forward(self, predictions, target):
        """Compute the total YOLOv1 loss.

        predictions: shape (BATCH_SIZE, S*S*(C + B*5)) as output by the model.
        target: ground-truth grid; last dim laid out as
                [20 class scores, objectness, 4 box coords] — TODO confirm
                against the dataset's label encoding.
        """
        predictions = predictions.reshape(-1, self.S, self.S, self.C + self.B * 5)

        # IoU of each of the two predicted boxes against the target box.
        iou_b1 = intersection_over_union(predictions[..., 21:25], target[..., 21:25])
        iou_b2 = intersection_over_union(predictions[..., 26:30], target[..., 21:25])
        ious = torch.cat([iou_b1.unsqueeze(0), iou_b2.unsqueeze(0)], dim=0)

        # best_box is 0 or 1: index of the predictor responsible for the object
        # (the one with the higher IoU).
        iou_maxes, best_box = torch.max(ious, dim=0)
        exists_box = target[..., 20].unsqueeze(3)  # Iobj_i in the paper

        # ======================== #
        #   FOR BOX COORDINATES    #
        # ======================== #
        # Zero out boxes in cells with no object; keep only the responsible
        # predictor's coordinates.
        box_predictions = exists_box * (
            best_box * predictions[..., 26:30]
            + (1 - best_box) * predictions[..., 21:25]
        )
        box_targets = exists_box * target[..., 21:25]

        # sqrt of width/height as in the paper; sign/abs keep the sqrt defined
        # for possibly-negative raw predictions, 1e-6 for numerical stability.
        box_predictions[..., 2:4] = torch.sign(box_predictions[..., 2:4]) * torch.sqrt(
            torch.abs(box_predictions[..., 2:4] + 1e-6)
        )
        box_targets[..., 2:4] = torch.sqrt(box_targets[..., 2:4])

        box_loss = self.mse(
            torch.flatten(box_predictions, end_dim=-2),
            torch.flatten(box_targets, end_dim=-2),
        )

        # ==================== #
        #   FOR OBJECT LOSS    #
        # ==================== #
        # pred_box: confidence score of the responsible predictor.
        pred_box = (
            best_box * predictions[..., 25:26]
            + (1 - best_box) * predictions[..., 20:21]
        )
        object_loss = self.mse(
            torch.flatten(exists_box * pred_box),
            torch.flatten(exists_box * target[..., 20:21]),
        )

        # ======================= #
        #   FOR NO OBJECT LOSS    #
        # ======================= #
        # Penalize the confidence of BOTH predictors in cells with no object.
        no_object_loss = self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 20:21], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
        )
        no_object_loss += self.mse(
            torch.flatten((1 - exists_box) * predictions[..., 25:26], start_dim=1),
            torch.flatten((1 - exists_box) * target[..., 20:21], start_dim=1),
        )

        # ================== #
        #   FOR CLASS LOSS   #
        # ================== #
        class_loss = self.mse(
            torch.flatten(exists_box * predictions[..., :20], end_dim=-2),
            torch.flatten(exists_box * target[..., :20], end_dim=-2),
        )

        loss = (
            self.lambda_coord * box_loss          # first two rows in the paper
            + object_loss                         # third row
            + self.lambda_noobj * no_object_loss  # fourth row
            + class_loss                          # fifth row
        )
        return loss
模型训练 (Model Training)
模型的保存与读取:save_checkpoint 和 load_checkpoint 两个函数分别负责保存和读取训练好的模型(checkpoint)。
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.transforms.functional as FT
from torch.utils.data import DataLoader
from loss import YoloLoss
from model import Yolov1
from dataset import VOCDataset
from tqdm import tqdm
from utils import(
intersection_over_union,
non_max_suppression,
mean_average_precision,
cellboxes_to_boxes,
get_bboxes,
plot_image,
save_checkpoint,
load_checkpoint,)
# Reproducibility: fix the RNG seed for weight init and data shuffling.
seed = 123
torch.manual_seed(seed)

# Hyperparameters etc.
LEARN_RATE = 2e-5
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 4
WEIGHT_DECAY = 0
EPOCHS = 100
NUM_WORKERS = 2
PIN_MEMORY = True
LOAD_MODEL = False
LOAD_MODEL_FILE = 'overfit.pth.tar'
IMG_DIR = 'data/images'
LABEL_DIR = 'data/labels'


class Compose(object):
    """Apply a sequence of image transforms, passing bboxes through unchanged.

    Unlike torchvision's Compose, __call__ takes (img, bboxes); each transform
    is applied to the image only, and the bounding boxes are returned as-is
    (presumably relative coordinates unaffected by Resize/ToTensor — verify
    against the dataset's label format).
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img, bboxes):
        for t in self.transforms:
            img, bboxes = t(img), bboxes
        return img, bboxes
# Image pipeline: resize to the 448x448 input the YOLOv1 model expects,
# then convert to a tensor. Bounding boxes pass through Compose unchanged.
transform = Compose([transforms.Resize((448, 448)), transforms.ToTensor()])


def train_fn(train_loader, model, optimizer, loss_fn):
    """Run one epoch of training over ``train_loader`` and print the mean loss.

    Mutates ``model`` (parameter update) and ``optimizer`` state in place.
    """
    loop = tqdm(train_loader, leave=True)
    mean_loss = []
    for batch_idx, (x, y) in enumerate(loop):
        x, y = x.to(DEVICE), y.to(DEVICE)
        out = model(x)
        loss = loss_fn(out, y)
        mean_loss.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Update the progress bar with the current batch loss.
        loop.set_postfix(loss=loss.item())
    print(f"Mean loss was {sum(mean_loss)/len(mean_loss)}")


def main():
    """Build model/data, then alternate mAP evaluation and training per epoch."""
    model = Yolov1(split_size=7, num_boxes=2, num_classes=20).to(DEVICE)
    optimizer = optim.Adam(
        model.parameters(), lr=LEARN_RATE, weight_decay=WEIGHT_DECAY
    )
    loss_fn = YoloLoss()

    if LOAD_MODEL:
        load_checkpoint(torch.load(LOAD_MODEL_FILE), model, optimizer)

    train_dataset = VOCDataset(
        "data/8examples.csv",
        transform=transform,
        img_dir=IMG_DIR,
        label_dir=LABEL_DIR,
    )
    test_dataset = VOCDataset(
        "data/test.csv",
        transform=transform,
        img_dir=IMG_DIR,
        label_dir=LABEL_DIR,
    )
    train_loader = DataLoader(
        dataset=train_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
        drop_last=False,
    )
    # Constructed for parity with the tutorial setup; not used in the
    # training loop below.
    test_loader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        shuffle=True,
        drop_last=True,
    )

    for epoch in range(EPOCHS):
        # Evaluate mAP on the training set before each training pass.
        pred_boxes, target_boxes = get_bboxes(
            train_loader, model, iou_threshold=0.5, threshold=0.4
        )
        mean_avg_prec = mean_average_precision(
            pred_boxes, target_boxes, iou_threshold=0.5, box_format='midpoint'
        )
        print(f"Train mAP: {mean_avg_prec}")

        train_fn(train_loader, model, optimizer, loss_fn)


# 做程序测试用 (script entry point for testing the program)
if __name__ == "__main__":
    main()