The SVM discussed in this article is the Multiclass Support Vector Machine.
Multiclass Support Vector Loss
The SVM loss is set up so that the SVM “wants” the correct class for each image to have a score higher than the incorrect classes by some fixed margin Δ. Notice that it’s sometimes helpful to anthropomorphise the loss functions as we did above: The SVM “wants” a certain outcome in the sense that the outcome would yield a lower loss (which is good).
In other words: the SVM loss wants the score of each image on its correct class to be higher than its scores on the incorrect classes by at least the margin Δ.
Formula 1: $L_i = \sum_{j \neq y_i}\max(0, s_j - s_{y_i} + \Delta)$
$L_i$: the loss of the $i$-th sample
$s_j$: the score on class $j$
$s_{y_i}$: the score on the correct class $y_i$
From the formula we can see that if the image's score on the correct class exceeds its score on an incorrect class by at least $\Delta$, then that incorrect class contributes nothing to the loss.
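Formula 1 is easy to verify by hand. The sketch below (my own helper, not part of the assignment code) evaluates the hinge loss for the worked example from the CS231n notes, where the scores are [13, -7, 11], the correct class is 0, and Δ = 10:

```python
import numpy as np

def svm_loss_single(scores, y, delta):
    # Hinge loss for one example:
    # sum over incorrect classes of max(0, s_j - s_{y_i} + delta)
    margins = np.maximum(0, scores - scores[y] + delta)
    margins[y] = 0  # the correct class contributes nothing
    return margins.sum()

# max(0, -7 - 13 + 10) + max(0, 11 - 13 + 10) = 0 + 8 = 8
print(svm_loss_single(np.array([13.0, -7.0, 11.0]), 0, delta=10.0))  # -> 8.0
```

The first term is clipped to zero because the score gap (20) already exceeds the margin; only the second incorrect class contributes.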
Formula 2: $L_i = \sum_{j \neq y_i}\max(0, w_j^{T}x_i - w_{y_i}^{T}x_i + \Delta)$
Calculus
SVM Loss Function:

$L_i = \sum_{j \neq y_i}\max(0, w_j^{T}x_i - w_{y_i}^{T}x_i + \Delta)$

Gradient with respect to $w_{y_i}$:

$\nabla_{w_{y_i}} L_i = -\left(\sum_{j \neq y_i} \mathbb{1}\left(w_j^{T}x_i - w_{y_i}^{T}x_i + \Delta > 0\right)\right) x_i$

Gradient with respect to $w_j$ (for $j \neq y_i$):

$\nabla_{w_j} L_i = \mathbb{1}\left(w_j^{T}x_i - w_{y_i}^{T}x_i + \Delta > 0\right) x_i$
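The two gradient formulas translate directly into a naive double loop, in the spirit of the assignment's `svm_loss_naive`. The function below is a hypothetical sketch under the convention that `W` has shape (D, C) with one column per class, not the repository's actual code:

```python
import numpy as np

def svm_loss_and_grad(W, X, y, delta=1.0):
    # W: (D, C) weights, X: (N, D) data, y: (N,) integer labels.
    N = X.shape[0]
    loss = 0.0
    dW = np.zeros_like(W)
    for i in range(N):
        scores = X[i].dot(W)          # (C,) scores for sample i
        correct = scores[y[i]]
        for j in range(W.shape[1]):
            if j == y[i]:
                continue
            margin = scores[j] - correct + delta
            if margin > 0:            # the indicator 1(... > 0) fires
                loss += margin
                dW[:, j] += X[i]      # gradient w.r.t. w_j
                dW[:, y[i]] -= X[i]   # gradient w.r.t. w_{y_i}
    return loss / N, dW / N
```

Accumulating `+X[i]` into column `j` and `-X[i]` into column `y[i]` each time a margin is violated is exactly the pair of formulas above; averaging over N gives the gradient of the mean loss.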
SVM Code
github:https://github.com/GIGpanda/CS231n
It mainly consists of two .py files: svm.py and linear_svm.py.
svm.py
Data loading and visualization
# Multiclass Support Vector Machine
from __future__ import print_function
import random
import numpy as np
from cs231n.data_utils import load_CIFAR10
import matplotlib.pyplot as plt
from cs231n.classifiers.linear_svm import svm_loss_naive
import time
from cs231n.gradient_check import grad_check_sparse
from cs231n.classifiers.linear_svm import svm_loss_vectorized
from cs231n.classifiers import LinearSVM
import math
plt.rcParams['figure.figsize'] = (10.0, 8.0) # set default size of plots
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
# Load the raw CIFAR-10 data.
cifar10_dir = 'cs231n/datasets/cifar-10-batches-py'
# Cleaning up variables to prevent loading data multiple times (which may cause memory issue)
try:
    del X_train, y_train
    del X_test, y_test
    print('Clear previously loaded data.')
except NameError:
    pass
X_train, y_train, X_test, y_test = load_CIFAR10(cifar10_dir)
# As a sanity check, we print out the size of the training and test data.
print('Training data shape: ', X_train.shape)
print('Training labels shape: ', y_train.shape)
print('Test data shape: ', X_test.shape)
print('Test labels shape: ', y_test.shape)
# Visualize some examples from the dataset.
# We show a few examples of training images from each class.