Here is a simple example using the TensorFlow 2.0 SavedModel format (the format the docs recommend) for an MNIST classifier, built with the Keras functional API and nothing fancy:
# Imports
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()
# Scale pixels from [0, 255] into [0, 1] (tf.keras.utils.normalize applies L2 normalization along axis 1)
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)
# Create model
# Named input/output layers become the signature's input/output keys
inputs = Input(shape=(28,28), dtype='float64', name='graph_input')  # "inputs" avoids shadowing the builtin input()
x = Flatten()(inputs)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=inputs, outputs=output)
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# Train
model.fit(x_train, y_train, epochs=3)
# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)
# ... possibly another python program
# Reload the SavedModel (tf.saved_model.load exposes the .signatures mapping)
loaded_model = tf.saved_model.load(export_path)
# Get image sample for testing
index = 0
img = x_test[index]  # already normalized in a previous step
# Predict using the signature definition (TensorFlow 2.0)
predict = loaded_model.signatures["serving_default"]
# Signatures take keyword arguments named after the input layer and expect a batch dimension
prediction = predict(graph_input=tf.constant(img[tf.newaxis, ...]))
# Show results
print(int(tf.argmax(prediction['graph_output'], axis=-1)[0]))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # draws the image
plt.show()
What is serving_default?
It's the key of the signature def inside the MetaGraph identified by the tag you selected (in this case the default serve tag, which is what tf.saved_model.save uses). There is also documentation explaining how to find a model's tags and signatures using saved_model_cli.
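If you prefer to inspect this from Python instead of saved_model_cli, here is a minimal sketch, assuming TensorFlow 2.x. It saves a hypothetical toy module (Doubler, which is not part of the MNIST model above) with an explicit signatures dict, reloads it under the serve tag, and prints the signature keys plus the input and output structure:

```python
import tempfile

import tensorflow as tf


class Doubler(tf.Module):
    """Hypothetical toy module, used only to illustrate signatures."""

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 3], dtype=tf.float32, name="x")])
    def __call__(self, x):
        # Returning a dict makes "doubled" the signature's output key
        return {"doubled": 2.0 * x}


export_dir = tempfile.mkdtemp()
module = Doubler()
# Register __call__ explicitly under the "serving_default" signature key
tf.saved_model.save(module, export_dir, signatures={"serving_default": module.__call__})

# tf.saved_model.save writes a single MetaGraph tagged "serve" (tf.saved_model.SERVING)
loaded = tf.saved_model.load(export_dir, tags=[tf.saved_model.SERVING])
print(list(loaded.signatures.keys()))

serving_fn = loaded.signatures["serving_default"]
print(serving_fn.structured_input_signature)  # (args, kwargs) specs; inputs appear as kwargs
print(serving_fn.structured_outputs)          # dict of output TensorSpecs

# Loaded signatures are called with keyword arguments matching the input names
result = serving_fn(x=tf.constant([[1.0, 2.0, 3.0]]))
print(result["doubled"].numpy())
```

For the MNIST model above, the same inspection on loaded_model.signatures["serving_default"] should list graph_input as the input key and graph_output as the output key, which is why those names appear in the prediction call.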
Disclaimers
This is just a basic example to get you up and running; it is by no means a complete answer, and maybe I can update it in the future. I just wanted to give a simple example of using SavedModel in TF 2.0 because I haven't seen one, even one this simple, anywhere.
@Tom's answer is a SavedModel example, but it will not work on TensorFlow 2.0 because, unfortunately, there are some breaking changes.
@Vishnuvardhan Janapati's answer targets TF 2.0, but it does not use the SavedModel format.