[python] Save classifier to disk in scikit-learn

How do I save a trained Naive Bayes classifier to disk and use it to predict data?

I have the following sample program from the scikit-learn website:

from sklearn import datasets
iris = datasets.load_iris()
from sklearn.naive_bayes import GaussianNB
gnb = GaussianNB()
y_pred = gnb.fit(iris.data, iris.target).predict(iris.data)
print("Number of mislabeled points : %d" % (iris.target != y_pred).sum())


The answer is


What you are looking for is called model persistence in scikit-learn terminology. It is documented in the introduction and in the model persistence section of the scikit-learn documentation.

So you have initialized your classifier and trained it for a long time with

clf = some.classifier()
clf.fit(X, y)

After this you have two options:

1) Using Pickle

import pickle
# now you can save it to a file
with open('filename.pkl', 'wb') as f:
    pickle.dump(clf, f)

# and later you can load it
with open('filename.pkl', 'rb') as f:
    clf = pickle.load(f)

2) Using Joblib

import joblib  # standalone package; sklearn.externals.joblib was removed in scikit-learn 0.23
# now you can save it to a file
joblib.dump(clf, 'filename.pkl')
# and later you can load it
clf = joblib.load('filename.pkl')
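To tie this back to the question, here is a minimal sketch that trains the GaussianNB classifier on iris, saves it with pickle, reloads it, and predicts with the reloaded copy. The temporary directory and the file name gnb.pkl are assumptions made for illustration only.

```python
import os
import pickle
import tempfile

import numpy as np
from sklearn import datasets
from sklearn.naive_bayes import GaussianNB

iris = datasets.load_iris()
gnb = GaussianNB().fit(iris.data, iris.target)

# Hypothetical location: a fresh temporary directory.
path = os.path.join(tempfile.mkdtemp(), "gnb.pkl")

# Save the fitted classifier to disk.
with open(path, "wb") as f:
    pickle.dump(gnb, f)

# Later (possibly in another process): load it and predict.
with open(path, "rb") as f:
    restored = pickle.load(f)

same = np.array_equal(gnb.predict(iris.data), restored.predict(iris.data))
print("Predictions identical after reload:", same)
```

Only load pickle files you trust, since unpickling can execute arbitrary code.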

Once again, it is helpful to read the documentation sections mentioned above.


You can also use joblib.dump and joblib.load, which are much more efficient at handling numerical arrays than the default Python pickler.

joblib is installed as a dependency of scikit-learn:

>>> import joblib
>>> from sklearn.datasets import load_digits
>>> from sklearn.linear_model import SGDClassifier

>>> digits = load_digits()
>>> clf = SGDClassifier().fit(digits.data, digits.target)
>>> clf.score(digits.data, digits.target)  # evaluate training error
0.9526989426822482

>>> filename = '/tmp/digits_classifier.joblib.pkl'
>>> _ = joblib.dump(clf, filename, compress=9)

>>> clf2 = joblib.load(filename)
>>> clf2
SGDClassifier(alpha=0.0001, class_weight=None, epsilon=0.1, eta0=0.0,
       fit_intercept=True, learning_rate='optimal', loss='hinge', n_iter=5,
       n_jobs=1, penalty='l2', power_t=0.5, rho=0.85, seed=0,
       shuffle=False, verbose=0, warm_start=False)
>>> clf2.score(digits.data, digits.target)
0.9526989426822482

Edit: in Python 3.8+ it is now possible to use pickle for efficient pickling of objects with large numerical arrays as attributes if you use pickle protocol 5 (which was not the default when it was introduced).
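As a rough sketch of what protocol 5 looks like in practice, assuming Python 3.8+ and a NumPy version that supports out-of-band buffers (the weights dictionary is a hypothetical stand-in for a fitted estimator with a large array attribute):

```python
import pickle

import numpy as np

# Hypothetical stand-in for an estimator holding a large array.
weights = {"coef_": np.arange(1_000_000, dtype=np.float64)}

# In-band: protocol 5 is requested explicitly and used like any other protocol.
blob = pickle.dumps(weights, protocol=5)
restored = pickle.loads(blob)

# Out-of-band: large buffers are handed to the callback instead of being
# copied into the pickle stream, so the stream itself stays small.
buffers = []
small_blob = pickle.dumps(weights, protocol=5, buffer_callback=buffers.append)
restored_oob = pickle.loads(small_blob, buffers=buffers)

print("in-band stream size:", len(blob))
print("out-of-band stream size:", len(small_blob))
```

With out-of-band buffers the caller is responsible for storing or transmitting the buffers alongside the pickle stream.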


sklearn.externals.joblib has been deprecated since scikit-learn 0.21 and was removed in 0.23:

/usr/local/lib/python3.7/site-packages/sklearn/externals/joblib/__init__.py:15: FutureWarning: sklearn.externals.joblib is deprecated in 0.21 and will be removed in 0.23. Please import this functionality directly from joblib, which can be installed with: pip install joblib. If this warning is raised when loading pickled models, you may need to re-serialize those models with scikit-learn 0.21+.
warnings.warn(msg, category=FutureWarning)


Therefore, you need to install joblib:

pip install joblib

and finally write the model to disk:

import joblib
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier


digits = load_digits()
clf = SGDClassifier().fit(digits.data, digits.target)

with open('myClassifier.joblib.pkl', 'wb') as f:
    joblib.dump(clf, f, compress=9)

Now in order to read the dumped file all you need to run is:

with open('myClassifier.joblib.pkl', 'rb') as f:
    my_clf = joblib.load(f)

In many cases, particularly with text classification, it is not enough to store just the classifier. You also need to store the vectorizer so that you can vectorize new input in the future.

import pickle
with open('model.pkl', 'wb') as fout:
  pickle.dump((vectorizer, clf), fout)

Later, load both objects and use them to predict:

with open('model.pkl', 'rb') as fin:
  vectorizer, clf = pickle.load(fin)

X_new = vectorizer.transform(new_samples)
X_new_preds = clf.predict(X_new)
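The snippets above leave vectorizer, clf and new_samples undefined. A self-contained sketch, assuming a TfidfVectorizer, an SGDClassifier and a tiny made-up corpus (all chosen here for illustration, not taken from the original answer), could look like this:

```python
import pickle

from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import SGDClassifier

# Hypothetical toy corpus, only for illustration.
texts = ["free money now", "win a cash prize",
         "meeting at noon", "project status update"]
labels = [1, 1, 0, 0]

vectorizer = TfidfVectorizer()
clf = SGDClassifier(random_state=0)
clf.fit(vectorizer.fit_transform(texts), labels)

# Serialize both objects together (to bytes here; a file works the same way).
blob = pickle.dumps((vectorizer, clf))

# Later: restore both and run the same transform -> predict steps.
vectorizer2, clf2 = pickle.loads(blob)
new_samples = ["cash prize now", "status meeting"]
preds_before = clf.predict(vectorizer.transform(new_samples))
preds_after = clf2.predict(vectorizer2.transform(new_samples))
print(preds_before, preds_after)
```

Storing the pair in one file keeps the vocabulary and the coefficients in sync, which matters because the classifier's columns only make sense for the vectorizer it was trained with.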

Before dumping the vectorizer, you can clear its stop_words_ attribute with:

vectorizer.stop_words_ = None

to make dumping more efficient. Also, if your classifier's parameters are sparse (as in most text classification examples), you can convert them from dense to sparse, which can make a large difference in memory consumption, loading, and dumping. Sparsify the model with:

clf.sparsify()

This works out of the box for SGDClassifier. If you know your model is sparse (lots of zeros in clf.coef_), you can also manually convert clf.coef_ into a SciPy CSR sparse matrix with:

import scipy.sparse

clf.coef_ = scipy.sparse.csr_matrix(clf.coef_)

and then you can store it more efficiently.
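To see the effect of sparsify() on a concrete model, here is a small sketch. It assumes the digits dataset and an L1-penalized SGDClassifier (both picked for illustration); how much space is saved depends on how many coefficients are actually zero, so the sizes are printed rather than assumed.

```python
import pickle

import scipy.sparse
from sklearn.datasets import load_digits
from sklearn.linear_model import SGDClassifier

digits = load_digits()

# An L1 penalty tends to push many coefficients to exactly zero.
clf = SGDClassifier(penalty="l1", alpha=0.01, random_state=0)
clf.fit(digits.data, digits.target)

preds_dense = clf.predict(digits.data)
size_dense = len(pickle.dumps(clf))

# Convert coef_ to a scipy.sparse matrix in place.
clf.sparsify()
preds_sparse = clf.predict(digits.data)
size_sparse = len(pickle.dumps(clf))

agreement = (preds_dense == preds_sparse).mean()
print("coef_ is sparse:", scipy.sparse.issparse(clf.coef_))
print("pickle size dense vs sparse:", size_dense, size_sparse)
print("prediction agreement:", agreement)
```

If most coefficients are non-zero, the sparse representation can end up larger than the dense one, so it is worth checking the sizes on your own model.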


sklearn estimators implement methods that make it easy to save the relevant trained properties of an estimator. Some estimators implement __getstate__ themselves, but others, like the GMM, just use the base implementation, which simply saves the object's inner dictionary:

def __getstate__(self):
    try:
        state = super(BaseEstimator, self).__getstate__()
    except AttributeError:
        state = self.__dict__.copy()

    if type(self).__module__.startswith('sklearn.'):
        return dict(state.items(), _sklearn_version=__version__)
    else:
        return state
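To see what this base implementation returns for a fitted estimator, you can call __getstate__ directly. This is a small sketch using GaussianNB on iris (chosen here for illustration):

```python
import sklearn
from sklearn import datasets
from sklearn.naive_bayes import GaussianNB

iris = datasets.load_iris()
gnb = GaussianNB().fit(iris.data, iris.target)

# The state is the instance dictionary plus the scikit-learn version
# that produced it; this is what pickle stores.
state = gnb.__getstate__()
print(sorted(state))
print("saved with scikit-learn", state["_sklearn_version"])
```

The stored version is what lets scikit-learn warn you when a model is unpickled under a different release.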

The recommended method to save your model to disk is to use the pickle module:

import pickle

from sklearn import datasets
from sklearn.svm import SVC

iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]

model = SVC()
model.fit(X, y)

with open('mymodel', 'wb') as f:
    pickle.dump(model, f)

However, you should save additional data so you can retrain your model in the future, or suffer dire consequences (such as being locked into an old version of sklearn).

From the documentation:

In order to rebuild a similar model with future versions of scikit-learn, additional metadata should be saved along the pickled model:

The training data, e.g. a reference to an immutable snapshot

The python source code used to generate the model

The versions of scikit-learn and its dependencies

The cross validation score obtained on the training data

This is especially true for ensemble estimators that rely on the tree.pyx module written in Cython (such as IsolationForest), since this couples the pickled model to an implementation that is not guaranteed to be stable between versions of sklearn. It has seen backwards-incompatible changes in the past.
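One way to keep the metadata listed above next to the pickled model is a small JSON file. The sketch below is only one possible layout: the file names, the dictionary keys, the script name train_model.py and the 5-fold cross-validation are all assumptions made for illustration.

```python
import json
import os
import pickle
import platform
import tempfile

import numpy
import scipy
import sklearn
from sklearn import datasets
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

iris = datasets.load_iris()
X = iris.data[:100, :2]
y = iris.target[:100]

model = SVC()
cv_scores = cross_val_score(model, X, y, cv=5)  # assumed 5-fold CV
model.fit(X, y)

metadata = {
    "training_data": "sklearn.datasets.load_iris, rows 0-99, first two features",
    "source_code": "train_model.py",  # hypothetical script name
    "versions": {
        "python": platform.python_version(),
        "scikit-learn": sklearn.__version__,
        "numpy": numpy.__version__,
        "scipy": scipy.__version__,
    },
    "cv_mean_accuracy": float(cv_scores.mean()),
}

# Hypothetical output location: a fresh temporary directory.
out_dir = tempfile.mkdtemp()
model_path = os.path.join(out_dir, "mymodel.pkl")
meta_path = os.path.join(out_dir, "mymodel.meta.json")

with open(model_path, "wb") as f:
    pickle.dump(model, f)
with open(meta_path, "w") as f:
    json.dump(metadata, f, indent=2)
```

With this information you can recreate a matching environment to load the old pickle, or retrain an equivalent model under a newer scikit-learn.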

If your models become very large and loading becomes a nuisance, you can also use the more efficient joblib. From the documentation:

In the specific case of the scikit, it may be more interesting to use joblib’s replacement of pickle (joblib.dump & joblib.load), which is more efficient on objects that carry large numpy arrays internally as is often the case for fitted scikit-learn estimators, but can only pickle to the disk and not to a string.

