In [2]:
%%capture
import matplotlib.pyplot as plt
from celluloid import Camera # pour l'animation
from IPython.display import HTML # pour la vidéo
import numpy as np

def ligne(p1, p2):
    plt.plot([p1[0], p2[0]], [p1[1], p2[1]], "b:", linewidth = 1)

def dist(p1, p2):
    return np.sum((p1 - p2)**2)**.5
    
def kmeans(k, X, n):
    '''
    k: nombre de classes
    X: données
    n: nombre d'itérations
    '''
    Y = np.zeros((len(X)), dtype = np.int64) # classes
    fig = plt.figure()
    camera = Camera(fig)
    rng = np.random.default_rng()
    centres = rng.choice(X, size = k, replace = False) # centres choisis initialement parmi les données (d'autres choix sont possibles)
    plt.scatter(X[:, 0], X[:, 1], c = "blue")
    camera.snap()
    plt.scatter(X[:, 0], X[:, 1], c = "blue")
    plt.scatter(centres[:, 0], centres[:, 1], s = 200, c = "orange")
    camera.snap()
    for i in range(n):
        plt.scatter(X[:, 0], X[:, 1], c = "blue")
        plt.scatter(centres[:, 0], centres[:, 1], s = 200, c = "orange")
        for j in range(len(X)):
            Y[j] = np.argmin([dist(c, X[j]) for c in centres]) # X[j] est assigné au centre le plus proche
            ligne(centres[Y[j]], X[j])
        camera.snap()
        centres = np.array([X[Y == i].mean(0) for i in range(k)])
    return camera.animate(interval = 2000, repeat = True, repeat_delay = 500)
            
X1 = np.array([2, 2]) + np.random.randn(10, 2)
X2 = np.array([-2, -2]) + np.random.randn(10, 2)
X3 = np.array([-2, 2]) + np.random.randn(10, 2)
X4 = np.array([2, -2]) + np.random.randn(10, 2)

anim = kmeans(4, np.concatenate((X1, X2, X3, X4)), 5)
In [4]:
HTML(anim.to_html5_video())
Out[4]: