PRML 混合ガウス分布の EM アルゴリズムを R で実装してみた

PRML 9 章の混合ガウス分布の EM アルゴリズムを勉強のために実装してみた。(より本格的な実装と検証は id:n_shuyo さんのEM アルゴリズム実装(勉強用) - Mi manca qualche giovedi`?を参照のこと)。


今回初めての R だったので色々苦労したが、Rは良く出来ていてとても感心した。


真の分布を定義したのち伝承サンプリングでデータを生成し、Eステップ、Mステップを回して収束させた。

# データ生成
xx <- ancestralSampling(1000)

# データを描いてみる
plot(xx);

# K=2 D=2 の混合ガウス分布を真の分布として定義。
start_pi <- list(rnorm(1, 0.5), rnorm(1, 0.5));
start_mu <- list(c(rnorm(1, 10), rnorm(1, 10)), c(rnorm(1, 10), rnorm(1, 10)));
start_sigma <- list(matrix(c(1,0,0,1), 2, 2), matrix(c(1,0,0,1), 2, 2));

# 一度目の EM ステップ
gammaKn = Estep(xx, start_pi, start_mu, start_sigma)
v <- Mstep(xx, gammaKn);

# 30回ほど繰り返すと収束する
for(n in 1:30) {
  gammaKn = Estep(xx, v[[1]], v[[2]], v[[3]]);
  v <- Mstep(xx, gammaKn);
  cat(sprintf("n=%d mu=v%s\n", n, v[[1]]));
}


今回は以下の2つのパッケージを利用して楽をした

全コード

バグを見つけたら教えてください :D

# PRML 9.2 Mixture of Gaussians
#   There are two Gaussians.
library(mvtnorm)
library(RUnit)

responsibility <- function(xn, k, K, pi, mu, sigma) {
  a <- pi[[k]] * dmvnorm(xn, mu[[k]], sigma[[k]]);
  b <- sum(sapply(1:K, function(j) { pi[[j]] * dmvnorm(xn, mu[[j]], sigma[[j]]) }));
  a / b;
}

Estep <- function(xx, pi, mu, sigma) {
  K <- length(mu);
  apply(xx, 1, function(x) {
    sapply(1:K, function(k) {
      responsibility(x, k, K, pi, mu, sigma);
    });
  });
}

nK <- function(xx, k, gammaKn) {
  ret <- 0;
  N <- nrow(xx);
  for(n in 1:N) {
    ret <- ret + gammaKn[k, n];
  }
  ret;
}

muNew <- function(xx, k, gammaKn, nK) {
  N <- nrow(xx);
  sum = c(0, 0);
  for(n in 1:N) {
    sum <- sum + gammaKn[k, n] * xx[n,];
  }
  sum / nK;
}

sigmaNew <- function(xx, k, gammaKn, muKNew, nK) {
  N <- nrow(xx);
  sum = c(0, 0, 0, 0);
  for(n in 1:N) {
    sum <- sum + gammaKn[k, n] * ((xx[n,] - muKNew) %*% t(xx[n,] - muKNew));
  }
  matrix(sum / nK, ncol=2);
}

piNew <- function(xx, k, gammaKn, nK) {
  nK / length(xx) * length(xx[1,]);
}

Mstep <- function(xx, gammaKn) {
  K <- length(xx[1,]);
  nKList <- lapply(1:K, function(k) { nK(xx, k, gammaKn); });
  piNext <- lapply(1:K, function(k) { piNew(xx, k, gammaKn, nKList[[k]]); });
  muNext <- lapply(1:K, function(k) { muNew(xx, k, gammaKn, nKList[[k]]); });
  sigmaNext <- lapply(1:K, function(k) { sigmaNew(xx, k, gammaKn, muNext[[k]], nKList[[k]]); });
  list(piNext, muNext, sigmaNext);
}

input.data <- function() {
  pi <- list(0.7, 0.3);
  mu <- list(c(6, 7), c(1, 1));
  sigma <- list(matrix(c(7,0,0,7), 2, 2), matrix(c(10,3,3,10), 2, 2));
  xx <- matrix(c(1, 2, 3, 4, 5, 6),ncol=2, byrow=TRUE);
  list(pi, mu, sigma, xx);
}

## Unit Tests
test.nK <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaKn <- Estep(xx, pi, mu, sigma);
  checkEqualsNumeric(nK(xx, 1, gammaKn), 1.613336, tolerance = 0.0001);
}

test.Estep <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaNk <- Estep(xx, pi, mu, sigma);
  checkEqualsNumeric(gammaNk, matrix(c(0.08630052, 0.5957016, 0.9313342,
                                       0.91369948, 0.4042984, 0.0686658), nrow=2, byrow=T), tolerance=0.0001);
}

test.Mstep <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaNk <- Estep(xx, pi, mu, sigma);
  checkEquals(Mstep(xx, gammaNk),
              list(list(0.5377787, 0.4622212), # pi
                   list(c(4.047560, 5.047560), c(1.781199, 2.781199)), # mu
                   list(matrix(c(1.425674, 1.425674,
                                 1.425674, 1.425674), byrow=T, ncol=2),
                        matrix(c(1.348276, 1.348276,
                                 1.348276, 1.348276), byrow=T, ncol=2))),
              tolerance=0.0001);
}

test.muKNew <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaKn <- Estep(xx, pi, mu, sigma);
  muKNew <- muNew(xx, 1, gammaKn, nK(xx, 1, gammaKn));
  checkEqualsNumeric(muKNew, c(4.047560, 5.047560), tolerance=0.0001);
}

test.sigmaNew <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaKn <- Estep(xx, pi, mu, sigma);
  muKNew <- muNew(xx, 1, gammaKn, nK(xx, 1, gammaKn));
  checkEqualsNumeric(sigmaNew(xx, 1, gammaKn, muKNew, nK(xx, 1, gammaKn)),
                     matrix(c(1.425674, 1.425674,
                              1.425674, 1.425674), byrow=T, ncol=2), tolerance=0.0001);
## sigmaNew
## > (gammaNk[1, 1] * (c(1, 2) - muKNew)  %*% t((c(1, 2) - muKNew)) +
## + gammaNk[1, 2] * (c(3, 4) - muKNew)  %*% t((c(3, 4) - muKNew)) +
## + gammaNk[1, 3] * (c(5, 6) - muKNew)  %*% t((c(5, 6) - muKNew))) / 1.613336
##          [,1]     [,2]
## [1,] 1.425674 1.425674
## [2,] 1.425674 1.425674
## >
}

test.piKNew <- function() {
  input <- input.data();
  pi <- input[[1]];
  mu <- input[[2]];
  sigma <- input[[3]];
  xx <- input[[4]];
  gammaKn <- Estep(xx, pi, mu, sigma);
  k <- 1;
  checkEqualsNumeric(piNew(xx, k, gammaKn, nK(xx, k, gammaKn)), 0.5377787, tolerance=0.0001);
}

# runTestFile("./em.R")