import torch
from nets.yolo import YoloBody
from utils.utils import get_classes
# Export script: rebuild the YOLO network, restore trained weights from a
# checkpoint, and save a traced TorchScript copy of the model to "yolo.pt".

# Input locations: the class-list file and the checkpoint to restore.
classes_path = 'model_data/voc_classes.txt'
model_path = 'logs/last_epoch_weights.pth'

# Device used when deserializing the checkpoint tensors. It is only passed as
# `map_location` below; this script does not move the model or the example
# input onto it.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# get_classes presumably returns the class names plus how many there are
# (confirm against utils.utils); the count is what the network constructor
# receives.
class_names, num_classes = get_classes(classes_path)
model = YoloBody(num_classes)

# Restore the trained parameters, then switch to eval mode before tracing.
# NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()

# Tracing runs the model once on this random 1x3x640x640 tensor and records
# the operations that were executed.
example = torch.rand((1, 3, 640, 640))
output = torch.jit.trace(model, example_inputs=example)
torch.jit.save(output, "yolo.pt")