Gaussiennes - Apprentissage naïf

par Joseph Razik, le 2019-10-18
Gaussiennes
In [1]:
def gauss(mu, sigma, x):
    """ fonction calculant la probabilité d'appartenance d'un point à une gaussienne (sorte de distance) """
    return 1/(sigma*sqrt(2*pi))*exp(-((x-mu))**2/(2*sigma**2))
In [2]:
# a quoi ressemble une gaussienne type
mu = 0.0
sigma = 1.0
G = [mu, sigma]
figure(figsize=(10,10))
plot(linspace(-5, 5, 1000), [gauss(mu, sigma, t) for t in linspace(-5, 5, 1000)])
Out[2]:
[<matplotlib.lines.Line2D at 0x2bd46d0>]

on génére dix valeurs entre 0 et 2 et des valeurs entre 1 et 3 pour obtenir un modèle bimodale

In [7]:
v1 = [rand()*2 for i in xrange(10)]
In [9]:
v2 = [rand()*2 + 1 for i in xrange(10)]

On représente toutes les données par une seule gaussienne

In [18]:
mu = mean(v1+v2)
In [20]:
sigma = var(v1+v2)
In [29]:
figure(figsize=(10,10))
plot(linspace(0, 3, 1000), [gauss(mu, sigma, t) for t in linspace(0, 3, 1000)])
plot(v1+v2, [0]*len(v1+v2), 'ro')
Out[29]:
[<matplotlib.lines.Line2D at 0x3849310>]

Maintenant on va essayer d'apprendre un modèle à 2 gaussiennes. Pour cela on part de la mono-gaussienne précédente et on décale leur centroide de pas grand chose

In [49]:
mu_1 = mu - sigma
sigma_1 = sigma
mu_2 = mu + sigma
sigma_2 = sigma
G1 = [mu_1, sigma_1]
G2 = [mu_2, sigma_2]
In [50]:
# on affiche cette modélisation initiale
figure(figsize=(10,10))
plot(linspace(-1, 4, 1000), [gauss(mu_1, sigma_1, t) for t in linspace(-1, 4, 1000)], 'b')
plot(linspace(-1, 4, 1000), [gauss(mu_2, sigma_2, t) for t in linspace(-1, 4, 1000)], 'g')
plot(v1+v2, [0]*len(v1+v2), 'ro')
Out[50]:
[<matplotlib.lines.Line2D at 0x5080b10>]

pour chaque valeur échantillon on calcul la distance avec chacune des gaussiennes

In [51]:
d_1_2 = {x:(gauss(mu_1, sigma_1, x), gauss(mu_2, sigma_2, x)) for x in v1+v2}

on affecte les points à la gaussienne la plus proche.

In [52]:
x_1 = [x for x in v1+v2 if d_1_2[x][0] < d_1_2[x][1]]
x_2 = [x for x in v1+v2 if d_1_2[x][0] >= d_1_2[x][1]]
In [53]:
print str(len(x_1)) + " " + str(len(x_2))
11 9

On recalcul les moyennes et ecart-types

In [54]:
mu_1 = mean(x_1)
sigma_1 = var(x_1)
mu_2 = mean(x_2)
sigma_2 = var(x_2)

on affiche les nouvelles gaussiennes

In [55]:
figure(figsize=(10,10))
plot(linspace(-1, 4, 1000), [gauss(mu_1, sigma_1, t) for t in linspace(-1, 4, 1000)], 'b')
plot(linspace(-1, 4, 1000), [gauss(mu_2, sigma_2, t) for t in linspace(-1, 4, 1000)], 'g')
plot(v1+v2, [0]*len(v1+v2), 'ro')
Out[55]:
[<matplotlib.lines.Line2D at 0x506ea10>]

Et on recommence ... au moins 9 fois pour voir

In [56]:
figure(figsize=(15,15))

for i in xrange(9):
    d_1_2 = {x:(gauss(mu_1, sigma_1, x), gauss(mu_2, sigma_2, x)) for x in v1+v2}
    x_1 = [x for x in v1+v2 if d_1_2[x][0] >= d_1_2[x][1]]
    x_2 = [x for x in v1+v2 if d_1_2[x][0] < d_1_2[x][1]]
    mu_1 = mean(x_1)
    sigma_1 = var(x_1)
    mu_2 = mean(x_2)
    sigma_2 = var(x_2)
    subplot(3, 3, i+1)
    plot(linspace(-1, 4, 1000), [gauss(mu_1, sigma_1, t) for t in linspace(-1, 4, 1000)], 'b')
    plot(linspace(-1, 4, 1000), [gauss(mu_2, sigma_2, t) for t in linspace(-1, 4, 1000)], 'g')
    plot(v1+v2, [0]*len(v1+v2), 'ro')

On voit qu'on a un problème (normalement il y a des messages d'erreur) Le problème est qu'une des gaussiennes se concentre sur une seule valeur et que donc son sigma vaut zéro, d'où la division par zéro. La solution ? toujours mettre une valeur minimale pour sigma si cette valeur devient trop petite.

In [63]:
# on réinitialise
mu_1 = mu - sigma
sigma_1 = sigma
mu_2 = mu + sigma
sigma_2 = sigma
G1 = [mu_1, sigma_1]
G2 = [mu_2, sigma_2]

# et c'est reparti
figure(figsize=(15,15))

for i in xrange(9):
    d_1_2 = {x:(gauss(mu_1, sigma_1, x), gauss(mu_2, sigma_2, x)) for x in v1+v2}
    x_1 = [x for x in v1+v2 if d_1_2[x][0] >= d_1_2[x][1]]
    x_2 = [x for x in v1+v2 if d_1_2[x][0] < d_1_2[x][1]]
    mu_1 = mean(x_1)
    sigma_1 = var(x_1) 
    sigma_1 = sigma_1 if sigma_1 >= 0.001 else 0.001
    mu_2 = mean(x_2)
    sigma_2 = var(x_2)
    sigma_2 = sigma_2 if sigma_2 >= 0.001 else 0.001
    subplot(3, 3, i+1)
    plot(linspace(-1, 4, 1000), [gauss(mu_1, sigma_1, t) for t in linspace(-1, 4, 1000)], 'b')
    plot(linspace(-1, 4, 1000), [gauss(mu_2, sigma_2, t) for t in linspace(-1, 4, 1000)], 'g')
    plot(v1+v2, [0]*len(v1+v2), 'ro')

Etrange, voyons les valeurs des moyennes et le nombre d'échantillons pour chacune des gaussiennes

In [64]:
print mu_1, sigma_1, len(x_1)
print mu_2, sigma_2, len(x_2)
0.534605884709 0.001 1
1.42277763565 0.347452568374 19

On voit donc qu'une des gaussiennes ne représente qu'une seule valeur alors que l'autre représente toutes les autres. Pas de chance dans le tirage au sort des valeurs

Faisons un deuxième test avec une bimodalité plus franche.

In [67]:
v1 = [rand()*2 for i in xrange(10)]
v2 = [rand()*2 + 3 for i in xrange(10)]
mu = mean(v1+v2)
sigma = var(v1+v2)
figure(figsize=(10,10))
plot(linspace(0, 5, 1000), [gauss(mu, sigma, t) for t in linspace(0, 5, 1000)])
plot(v1+v2, [0]*len(v1+v2), 'ro')
Out[67]:
[<matplotlib.lines.Line2D at 0x3702910>]
In [69]:
# on initialise
mu_1 = mu - sigma
sigma_1 = sigma
mu_2 = mu + sigma
sigma_2 = sigma
G1 = [mu_1, sigma_1]
G2 = [mu_2, sigma_2]

# et c'est reparti
figure(figsize=(15,15))

for i in xrange(9):
    d_1_2 = {x:(gauss(mu_1, sigma_1, x), gauss(mu_2, sigma_2, x)) for x in v1+v2}
    x_1 = [x for x in v1+v2 if d_1_2[x][0] >= d_1_2[x][1]]
    x_2 = [x for x in v1+v2 if d_1_2[x][0] < d_1_2[x][1]]
    mu_1 = mean(x_1)
    sigma_1 = var(x_1) 
    sigma_1 = sigma_1 if sigma_1 >= 0.001 else 0.001
    mu_2 = mean(x_2)
    sigma_2 = var(x_2)
    sigma_2 = sigma_2 if sigma_2 >= 0.001 else 0.001
    subplot(3, 3, i+1)
    plot(linspace(-1, 5, 1000), [gauss(mu_1, sigma_1, t) for t in linspace(-1, 5, 1000)], 'b')
    plot(linspace(-1, 5, 1000), [gauss(mu_2, sigma_2, t) for t in linspace(-1, 5, 1000)], 'g')
    plot(v1+v2, [0]*len(v1+v2), 'ro')

Ici, on voit que dès la première itération chacune des gaussiennes est sur son jeu de données. Les 20 valeurs initiales peuvent venir d'un exemple, de deux exemples ou de 20 exemples, le calcul et les étapes sont les mêmes. On peut donc simplement concaténer les données.