Support Vector Machine

Soit $X$ un ensemble de données composé de deux classes étiquetées par $Y$ (les étiquettes étant 1 ou -1).
SVM est une méthode de classification linéaire.
Il sépare des données par un hyperplan qui a pour particularité de maximiser la marge, c'est à dire la distance de l'hyperplan aux données.
En définissant un hyperplan affine en dimension $d$ par son équation cartésienne $w^T x + b = 0$ ($x \in \mathbb{R}^d$), on peut montrer que l'hyperplan de marge maximum est obtenu par le problème d'optimisation quadratique suivant: $$\min ~\lVert w \rVert^2$$ Sous les contraintes: $$\forall x_i \in X, ~~y_i(w^T x_i + b) \geq 1$$ La distance d'une donnée $x$ à l'hyperplan est alors $\frac{\vert w^T x + b \vert}{\lVert w \rVert} \geq \frac{1}{\lVert w \rVert}$, maximisé lorsque $\lVert w \rVert$ et donc $\lVert w \rVert^2$ est minimisé.

Voir l'excellent site http://wikistat.fr/ pour plus de détails.

In [1]:
import cvxpy as cp # optimisation convexe
import matplotlib.pyplot as plt
import numpy as np
In [2]:
def svm(X, Y, d): 
    constraints = []
    w = cp.Variable(d)
    b = cp.Variable(1)
    for i in range(len(X)):
        constraints.append(Y[i]*(X[i]@w + b) >= 1) # on utilise des vecteurs lignes pour les données
    objective = cp.Minimize(cp.sum_squares(w))
    prob = cp.Problem(objective, constraints)
    result = prob.solve()
    return w.value, b.value
In [3]:
n, d = 10, 3 # deux classes de n variables gaussiennes en dimension d
X1 = np.array([3, 1, 1]) + np.random.randn(n, d)
X2 = np.array([-1, -2, -2]) + np.random.randn(n, d) # attention: très léger risque que les données ne soient pas linéairement séparables
X = np.concatenate((X1, X2))
Y = [1]*n + [-1]*n
w_opt, b_opt = svm(X, Y, d)
In [4]:
x = np.linspace(-5,5,10)
y = np.linspace(-5,5,10)
X,Y = np.meshgrid(x,y)
Z = -(X*w_opt[0] + Y*w_opt[1] + b_opt)/w_opt[2]

# avec matplotlib: 
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
fig = plt.figure()
ax = plt.axes(projection='3d')
surf = ax.plot_wireframe(X, Y, Z, alpha = .4)
ax.scatter(X1[:, 0], X1[:, 1], X1[:, 2])
ax.scatter(X2[:, 0], X2[:, 1], X2[:, 2]);
Out[4]:
<mpl_toolkits.mplot3d.art3d.Path3DCollection at 0x7f7ae4440e50>