The Data Science Lab
Gaussian Naive Bayes Classification Using the scikit Library
GaussianNB(*, priors=None, var_smoothing=1e-09)
When working with scikit, you'll spend most of your time reading the documentation and trying to figure out what each model parameter does. The priors parameter allows you to specify the initial probabilities of each target class. For example, for the demo training data there are 60 items of each of the three classes/species, and so the GaussianNB algorithm uses prior probabilities of 0.3333 for each class. If you passed [0.25, 0.50, 0.25] to the GaussianNB constructor, the algorithm would use those values for initial probabilities. The priors parameter is typically used when you have limited data in one target class and you want to specify equal initial probabilities.
The var_smoothing parameter adds a value to all predictor variances. The value added is a proportion of the largest predictor variance. The idea is that predictors with small variances can overwhelm predictors with large variances (because a predictor's variance appears in a fraction denominator during calculations). Adding a var_smoothing factor can artificially increase small variances to prevent a small-variance predictor from dominating the overall pseudo-probability calculations. In practice, the var_smoothing parameter is rarely used.
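For example, a minimal sketch of supplying both constructor parameters looks like this (the specific values shown are arbitrary and for illustration only, not the demo's values):

# a sketch -- explicit priors and a larger smoothing value
# (values are for illustration, not the demo's)
from sklearn.naive_bayes import GaussianNB
model = GaussianNB(priors=[0.25, 0.50, 0.25],
  var_smoothing=1e-05)  # default var_smoothing is 1e-09
model.fit(x_train, y_train)  # x_train, y_train as in the demo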
Evaluating the Trained Model
The demo computes the accuracy of the trained model like so:
# 3. evaluate model
acc_train = model.score(x_train, y_train)
print("Accuracy on train data = %0.4f " % acc_train)
acc_test = model.score(x_test, y_test)
print("Accuracy on test data = %0.4f " % acc_test)
The score() function computes a simple accuracy, which is just the number of correct predictions divided by the total number of predictions. However, for classification problems you usually want additional evaluation metrics to show how the model predicts for different target labels. For example, if a 100-item training dataset had 95 Kama seeds, 3 Rosa seeds and 2 Canadian seeds, then predicting Kama for any input would score 95 percent accuracy.
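One simple option is to compute accuracy for each class separately. Here is a minimal sketch, assuming x_test and y_test are NumPy arrays as in the demo:

# a sketch -- accuracy for each target class separately
# (assumes x_test, y_test are NumPy arrays as in the demo)
import numpy as np
y_pred = model.predict(x_test)
for c in range(3):  # classes 0, 1, 2
  idx = np.where(y_test == c)[0]
  acc_c = np.mean(y_pred[idx] == y_test[idx])
  print("class %d accuracy = %0.4f " % (c, acc_c))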
The scikit library has many ways to evaluate a trained classification prediction model. A good technique is to compute and display a confusion matrix:
# 3b. confusion matrix
from sklearn.metrics import confusion_matrix
y_predicteds = model.predict(x_test)
cm = confusion_matrix(y_test, y_predicteds)
print("Confusion matrix for test data: ")
# print(cm) # raw
show_confusion(cm) # formatted
For the demo test data, the output of the raw confusion matrix is:
[[6 0 4]
 [5 5 0]
 [1 0 9]]
A raw scikit confusion matrix is difficult to interpret so I usually implement a program-defined function called show_confusion() that adds basic labels. The output of show_confusion() is:
actual 0:  6  0  4
actual 1:  5  5  0
actual 2:  1  0  9
------------
predicted  0  1  2
The formatted output is much easier to interpret than the raw output. You can find the source code for the show_confusion() function in Listing 1.
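A minimal version of such a helper looks like the code below; the Listing 1 implementation may differ in formatting details:

def show_confusion(cm):
  # display a confusion matrix with basic row and column labels
  dim = len(cm)
  for i in range(dim):
    print("actual %d: " % i, end="")
    for j in range(dim):
      print("%2d " % cm[i][j], end="")
    print("")
  print("------------")
  print("predicted ", end="")
  for j in range(dim):
    print("%2d " % j, end="")
  print("")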
Using the Trained Model
The demo program uses the model to predict the species of a new, previously unseen dummy wheat seed:
# 4. use model
print("Predicting species all 0.2 predictors: ")
X = np.array([[0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2]],
dtype=np.float32)
print(X)
probs = model.predict_proba(X)
print("Prediction probs: ")
print(probs)
Notice the double square brackets on the x-input. The predict_proba() function expects a matrix rather than a vector. Because the GaussianNB model was trained using normalized data, the seven input values must be normalized in the same way, by dividing raw values by (25, 20, 1, 10, 10, 10, 10) respectively.
The return result from the predict_proba() function ("probabilities array") for the demo data is [[0.0033, 0.0000, 0.9967 ]]. The result has only one row because only one input was supplied. The three values in the row are the pseudo-probabilities of class 0, 1, and 2 respectively.
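For example, here is a sketch of normalizing a hypothetical raw seed in this way before calling predict_proba(). The raw values are made up for illustration; the divisors are the ones used for the demo data:

# a sketch -- normalize a hypothetical raw seed, then predict
# (raw values are made up; divisors are the demo's)
raw = np.array([14.1, 14.5, 0.87, 5.5, 3.3, 2.7, 5.2],
  dtype=np.float32)
divisors = np.array([25, 20, 1, 10, 10, 10, 10],
  dtype=np.float32)
X_new = (raw / divisors).reshape(1, -1)  # one-row matrix
probs_new = model.predict_proba(X_new)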
The demo program concludes with:
predicted = model.predict(X)
print("Predicted class: ")
print(predicted)
# 5. TODO: save model using pickle
print("End demo ")
if __name__ == "__main__":
  main()
The predict() method returns the predicted class, 0, 1, 2, rather than pseudo-probabilities.
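If you want a human-readable result, you can map the class index to a species name. Here is a minimal sketch, which assumes the encoding 0 = Kama, 1 = Rosa, 2 = Canadian:

# a sketch -- map class index to species name
# (assumes the encoding 0 = Kama, 1 = Rosa, 2 = Canadian)
species = ["Kama", "Rosa", "Canadian"]
predicted = model.predict(X)
print("Predicted species = " + species[int(predicted[0])])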
Saving the Trained Model
The demo doesn't save the trained model. When using scikit, the most common way to save a trained naive Bayes classifier model is to use the pickle library ("pickle" means to preserve, as in pickled cucumbers). For example:
import pickle
print("Saving Gaussian naive Bayes model ")
path = ".\\Models\\wheat_gnb_model.sav"
pickle.dump(model, open(path, "wb"))
This code assumes there is a directory named Models. The saved model could be loaded and used from another program like so:
# predict for unknown wheat seed
X = np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]],
dtype=np.float32)
with open(path, 'rb') as f:
  loaded_model = pickle.load(f)
pa = loaded_model.predict_proba(X)
print(pa)  # pseudo-probabilities
There are several other ways to save and load a trained scikit model, but using the pickle library is simplest.
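One such alternative (a sketch, not the approach used by the demo) is the joblib library:

# a sketch -- save and load using joblib instead of pickle
import joblib
joblib.dump(model, ".\\Models\\wheat_gnb_model.joblib")
loaded_model = joblib.load(".\\Models\\wheat_gnb_model.joblib")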
Wrapping Up
The main advantage of using Gaussian naive Bayes classification compared to other techniques like decision trees or neural networks is that you don't have to fine-tune model parameters. The main disadvantage of Gaussian NB is that the technique assumes all predictor variables are Gaussian distributed, which is often not true, and therefore the technique is typically not as powerful as a well-tuned decision tree or neural network model. Another disadvantage is that Gaussian NB requires all predictors to be numeric and so it can't handle categorical predictors (like color = "red", "blue", or "green") or Boolean predictors (like a person's sex = male or female).
Because Gaussian naive Bayes classification is so easy to use, my colleagues and I often use the technique to establish a baseline accuracy. Then a more powerful model, usually a neural network classifier, can be constructed with a rough idea of how accurate the model should be.
About the Author
Dr. James McCaffrey works for Microsoft Research in Redmond, Wash. He has worked on several Microsoft products including Azure and Bing. James can be reached at [email protected].