The Data Science Lab

Gaussian Mixture Model Data Clustering from Scratch Using C#

Loading the Data
The demo program begins by loading the dummy data into memory:

double[][] X = new double[10][];  // 10x2
X[0] = new double[] { 0.01, 0.10 };
X[1] = new double[] { 0.02, 0.09 };
X[2] = new double[] { 0.03, 0.10 };
X[3] = new double[] { 0.04, 0.06 };
X[4] = new double[] { 0.05, 0.06 };
X[5] = new double[] { 0.06, 0.05 };
X[6] = new double[] { 0.07, 0.05 };
X[7] = new double[] { 0.08, 0.01 };
X[8] = new double[] { 0.09, 0.02 };
X[9] = new double[] { 0.10, 0.01 };

For simplicity, the data is hard-coded. In a non-demo scenario, you'll probably want to read the data from a text file into memory using the helper MatLoad() method like so:

string dataFile = "..\\..\\..\\Data\\dummy_10.txt";
double[][] X = MatLoad(dataFile,
  new int[] { 0, 1 }, ',', "#");

The arguments instruct MatLoad() to use columns 0 and 1 of the text file, specify that the file is comma-delimited, and indicate that a "#" character at the start of a line marks a comment. The demo program displays the loaded data using the MatShow() function:

Console.WriteLine("Data: ");
MatShow(X, 2, 6);  // 2 decimals, 6 wide
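
The source code for MatLoad() and MatShow() isn't listed in this article. A minimal sketch that's consistent with the calls above -- my assumption of the implementation, not necessarily the exact code in the demo download -- looks like this (it needs using System.IO and using System.Collections.Generic):

static double[][] MatLoad(string fn, int[] usecols,
  char sep, string comment)
{
  List<double[]> rows = new List<double[]>();
  foreach (string line in File.ReadLines(fn)) {
    if (line.Trim() == "") continue;  // skip blank lines
    if (line.StartsWith(comment)) continue;  // skip comments
    string[] tokens = line.Split(sep);
    double[] row = new double[usecols.Length];
    for (int j = 0; j < usecols.Length; ++j)
      row[j] = double.Parse(tokens[usecols[j]]);
    rows.Add(row);
  }
  return rows.ToArray();
}

static void MatShow(double[][] m, int dec, int wid)
{
  for (int i = 0; i < m.Length; ++i) {
    for (int j = 0; j < m[i].Length; ++j)
      Console.Write(m[i][j].ToString("F" + dec).PadLeft(wid));
    Console.WriteLine("");
  }
}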

Like most clustering techniques, GMM clustering works best with strictly numeric data that has been normalized so that the magnitudes of all the columns are roughly the same, typically between 0 and 1 or between -1 and +1. Because the demo dataset is so small, it was constructed so that all values already lie between 0.01 and 0.10, and no additional normalization is needed.
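
If your source data isn't already on a common scale, a simple per-column min-max normalization can be applied before clustering. This hypothetical helper is not part of the demo program; it's just one common approach:

// hypothetical helper: scale each column to [0, 1] using
// min-max normalization; not part of the demo program
static double[][] MatMinMaxNormal(double[][] X)
{
  int nRows = X.Length; int nCols = X[0].Length;
  double[][] result = new double[nRows][];
  for (int i = 0; i < nRows; ++i)
    result[i] = new double[nCols];
  for (int j = 0; j < nCols; ++j) {
    double min = X[0][j]; double max = X[0][j];
    for (int i = 0; i < nRows; ++i) {
      if (X[i][j] < min) min = X[i][j];
      if (X[i][j] > max) max = X[i][j];
    }
    double range = max - min;
    for (int i = 0; i < nRows; ++i)
      result[i][j] = (range == 0.0) ? 0.5 :
        (X[i][j] - min) / range;  // constant column maps to 0.5
  }
  return result;
}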

Clustering the Data
The GMM clustering object is created and used like so:

int k = 3;  // number of clusters
int mi = 100;  // max iterations
Console.WriteLine("Creating C# scratch" +
  " GMM model k=3, maxIter=100 ");
GMM gmm = new GMM(k, mi, seed: 9);
Console.WriteLine("Clustering ");
int numIter = gmm.Cluster(X);
Console.WriteLine("Done. Used " +
  numIter + " iterations ");

Because GMM clustering is an iterative process, you should specify a maximum number of iterations. This value must be determined by experimentation. The GMM constructor accepts k, the number of clusters to construct, the maximum iterations and a seed value that is used by an internal Random object. The seed value of 9 is used only because it gives representative results. The Random object is used by a helper Shuffle() method, which is used by a helper Select() method, which is used by the primary Cluster() method to initialize the cluster means to k randomly selected data items. Whew!
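
The Shuffle() and Select() helpers aren't shown in this excerpt. The idea is a Fisher-Yates shuffle of the row indices followed by taking the first k rows as the initial means. A minimal sketch, where the exact signatures are my assumptions based on the description above:

// sketch of the initialization chain, assuming a class-scope
// Random field named rnd that was seeded in the constructor
private void Shuffle(int[] indices)
{
  // Fisher-Yates shuffle, in place
  for (int i = 0; i < indices.Length; ++i) {
    int r = this.rnd.Next(i, indices.Length);
    int tmp = indices[r];
    indices[r] = indices[i]; indices[i] = tmp;
  }
}

private double[][] Select(double[][] X, int n)
{
  // pick n distinct rows of X at random, to seed the means
  int[] indices = new int[X.Length];
  for (int i = 0; i < indices.Length; ++i) indices[i] = i;
  this.Shuffle(indices);
  double[][] result = new double[n][];
  for (int i = 0; i < n; ++i)
    result[i] = X[indices[i]];
  return result;
}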


The data is clustered by calling the Cluster() method. The clustering pseudo-probabilities are stored internally in an array-of-arrays member field named probs, and can be retrieved by calling the PredictProbs() method. The explicit return value of Cluster() is the number of EM iterations that were actually performed before the pseudo-probabilities stopped changing by much. Implementing the EM stopping criterion is the primary customization point in the demo program, as you'll see shortly.

Displaying the Clustering Results
The clustering results are displayed like so:

double[][] probs = gmm.PredictProbs(X);
int[] labels = gmm.PredictLabels(X);
Console.WriteLine("Clustering results: ");
for (int i = 0; i < probs.Length; ++i) {
  VecShow(X[i], 2, 6, false);
  Console.Write(" | ");
  VecShow(probs[i], 4, 9, false);
  Console.Write(" | ");
  Console.WriteLine(labels[i]);
}

The PredictProbs() method returns the pseudo-probabilities as an array-of-arrays style matrix. The PredictLabels() method returns the clustering as an integer array of 0-based cluster IDs. A cluster ID is just the index of the largest pseudo-probability. The output looks like:

0.01  0.10 |  0.0000   1.0000   0.0000 | 1
0.02  0.09 |  0.0000   1.0000   0.0000 | 1
0.03  0.10 |  0.0000   0.0000   1.0000 | 2
0.04  0.06 |  1.0000   0.0000   0.0000 | 0
0.05  0.06 |  0.0087   0.9296   0.0617 | 1
. . .
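
Because a cluster ID is just the index of the largest pseudo-probability, PredictLabels() can be implemented as a thin argmax wrapper over PredictProbs(). A minimal sketch -- my assumption of the implementation, not necessarily the demo's exact code:

// sketch: each label is the index of the largest
// pseudo-probability in the corresponding row of probs
public int[] PredictLabels(double[][] X)
{
  double[][] probs = this.PredictProbs(X);
  int[] labels = new int[probs.Length];
  for (int i = 0; i < probs.Length; ++i) {
    int maxIdx = 0;
    for (int j = 1; j < probs[i].Length; ++j)
      if (probs[i][j] > probs[i][maxIdx]) maxIdx = j;
    labels[i] = maxIdx;
  }
  return labels;
}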

The clustering information can be displayed in several ways. For example, instead of displaying by data item, you can display by cluster ID:

for (int cid = 0; cid < 3; ++cid) {
  Console.WriteLine("cid = " + cid);
  for (int i = 0; i < X.Length; ++i) {
    if (labels[i] == cid) {
      for (int j = 0; j < X[i].Length; ++j)
        Console.Write(X[i][j].ToString("F2").PadLeft(6));
      Console.WriteLine("");
    }
  }
}

The output would look like:

cid = 0
  0.04  0.06
  0.07  0.05

cid = 1
  0.01  0.10
  0.02  0.09
. . .

Instead of displaying all clustering information, it's possible to show the clustering pseudo-probabilities for a specific data item:

double[] x = new double[] { 0.05, 0.06 };
double[] p = gmm.PredictProbs(x);
Console.WriteLine("cluster pseudo-probs for" +
  " (0.05, 0.06) = ");
VecShow(p, 4, 9, true);  // add newline

The PredictProbs() method is overloaded to accept either multiple data items as an array-of-arrays matrix, or a single data item as a vector.
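
One plausible way to implement the single-item overload, though not necessarily how the demo does it, is to wrap the item in a one-row matrix and delegate to the matrix version:

// sketch: single-item overload delegates to the
// array-of-arrays version and returns its only row
public double[] PredictProbs(double[] x)
{
  double[][] X = new double[1][];
  X[0] = x;
  return this.PredictProbs(X)[0];
}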

The demo program concludes by displaying the means, covariance matrices, and coefficients using these statements:

Console.WriteLine("means: ");
MatShow(gmm.means, 4, 9);
Console.WriteLine("covariances: ");
for (int cid = 0; cid < 3; ++cid) {
  MatShow(gmm.covars[cid], 4, 9);
  Console.WriteLine("");
}
Console.WriteLine("coefficients: ");
VecShow(gmm.coefs, 4, 9, true);

The MatShow() and VecShow() helper functions make the Main() method a bit cleaner at the expense of yet more helpers.

Customizing the EM Stopping Condition
The demo program implementation stops iterating the expectation-maximization loop using this code:

int iter;
for (iter = 0; iter < this.maxIter; ++iter) {
  double oldMeanLogLike = MeanLogLikelihood();
  this.ExpectStep(X);  // update the probs
  double newMeanLogLike = MeanLogLikelihood();
  if (iter > this.maxIter / 2 && 
    Math.Abs(newMeanLogLike - oldMeanLogLike) < 1.0e-3) {
     break;
  }
  this.MaximStep(X);  // update coefs, means, covars
}

The EM loop stops when the maximum number of iterations has been reached, or when the number of iterations is at least half of the maximum and the difference between the average log-likelihoods computed before and after the expectation step is less than 0.001.

The average log-likelihood is computed by the MeanLogLikelihood() helper method. Suppose the pseudo-probabilities for one data item are (0.10, 0.85, 0.05). The log-likelihood for this item is log(0.10) + log(0.85) + log(0.05) = -2.30 + -0.16 + -3.00 = -5.46. For multiple data items, the average log-likelihood is the average of the individual log-likelihood values.
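
A sketch of the computation just described, assuming the class-scope probs matrix mentioned earlier (the demo's actual implementation may differ):

// sketch of MeanLogLikelihood(), assuming this.probs holds the
// current pseudo-probabilities, one row per data item.
// caution: Math.Log(0.0) is double.NegativeInfinity, so a
// production version needs a guard (see the discussion below)
private double MeanLogLikelihood()
{
  double sum = 0.0;
  for (int i = 0; i < this.probs.Length; ++i)
    for (int j = 0; j < this.probs[i].Length; ++j)
      sum += Math.Log(this.probs[i][j]);
  return sum / this.probs.Length;
}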

The demo GMM clustering implementation uses minimal change in average log-likelihood as the EM stopping condition mostly because that's the stopping condition used by the scikit-learn library GMM module. In my opinion, this stopping approach is a bit strange because a.) you have to deal with the possibility of log(0), which is negative infinity, and b.) because a sum of logs doesn't depend on the order of its terms, different pseudo-probability assignments, such as (x, y, z) and (z, y, x), produce exactly the same log-likelihood value.

There are many other stopping conditions you can implement. One approach I've used is to directly check whether any of the pseudo-probabilities have changed by more than some small epsilon value, as the sketch below illustrates.
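
For example, you could save a copy of the probs matrix before the expectation step and then compare element by element. A sketch of such a convergence check (hypothetical, not part of the demo program):

// sketch of an alternative stopping check: converged when no
// pseudo-probability changed by more than epsilon
static bool ProbsConverged(double[][] oldProbs,
  double[][] newProbs, double epsilon)
{
  for (int i = 0; i < oldProbs.Length; ++i)
    for (int j = 0; j < oldProbs[i].Length; ++j)
      if (Math.Abs(newProbs[i][j] - oldProbs[i][j]) > epsilon)
        return false;
  return true;
}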

Wrapping Up
It's important to remember that data clustering is an exploratory process. There is no single correct clustering of a dataset. In most data clustering scenarios there should be a human in the loop, where human expertise is used to examine clustering results to see if there are any interesting patterns.

Among my colleagues, there are mixed opinions about Gaussian mixture model clustering. Some of my colleagues believe that GMM clustering is arguably the most powerful clustering algorithm. On the other hand, some of my colleagues believe that GMM clustering is an example of an overly complex research solution in search of a problem. I have an intermediate opinion. I think that different clustering techniques -- k-means, density-based, spectral and GMM -- reveal different patterns and that in most scenarios it's not possible to know beforehand which clustering technique is most applicable.


About the Author

Dr. James McCaffrey works for Microsoft Research in Redmond, Wash. He has worked on several Microsoft products including Azure and Bing. James can be reached at [email protected].
