[python] Tensorflow: how to save/restore a model?

After you train a model in Tensorflow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?



Wherever you want to save the model,

self.saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ...
    self.saver.save(sess, filename)

Make sure all your tf.Variable objects have names, because you may want to restore them later by name. Then, wherever you want to predict,

saver = tf.train.import_meta_graph(filename)  # filename: path to the saved .meta file
name = 'name given when you saved the file'
with tf.Session() as sess:
    saver.restore(sess, name)
    print(sess.run('W1:0'))  # example to retrieve by variable name

Make sure the saver runs inside the corresponding session. Remember that if you use tf.train.latest_checkpoint('./'), only the latest checkpoint will be restored.


According to the new Tensorflow version, tf.train.Checkpoint is the preferable way of saving and restoring a model:

Checkpoint.save and Checkpoint.restore write and read object-based checkpoints, in contrast to tf.train.Saver which writes and reads variable.name based checkpoints. Object-based checkpointing saves a graph of dependencies between Python objects (Layers, Optimizers, Variables, etc.) with named edges, and this graph is used to match variables when restoring a checkpoint. It can be more robust to changes in the Python program, and helps to support restore-on-create for variables when executing eagerly. Prefer tf.train.Checkpoint over tf.train.Saver for new code.

Here is an example:

import tensorflow as tf
import os

tf.enable_eager_execution()

checkpoint_directory = "/tmp/training_checkpoints"
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")

checkpoint = tf.train.Checkpoint(optimizer=optimizer, model=model)
status = checkpoint.restore(tf.train.latest_checkpoint(checkpoint_directory))
for _ in range(num_training_steps):
  optimizer.minimize( ... )  # Variables will be restored on creation.
status.assert_consumed()  # Optional sanity checks.
checkpoint.save(file_prefix=checkpoint_prefix)

More information and example here.
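
The snippet above leaves `optimizer` and `model` undefined and targets TF 1.x eager mode. As a minimal self-contained sketch of the same object-based API, assuming TensorFlow 2.x (where eager execution is the default), the following tracks two plain tf.Variable objects instead of a model. The names `step` and `weights` and the temporary directory are illustrative assumptions, not part of the original answer.

```python
import os
import tempfile

import tensorflow as tf

# Objects to track; any trackable object (layers, optimizers, variables) works.
step = tf.Variable(0, dtype=tf.int64)
weights = tf.Variable([1.0, 2.0])
ckpt = tf.train.Checkpoint(step=step, weights=weights)

# Write a checkpoint into a throwaway directory (assumed location).
ckpt_dir = tempfile.mkdtemp()
save_path = ckpt.save(file_prefix=os.path.join(ckpt_dir, "ckpt"))

# Overwrite the variable, then restore it from the latest checkpoint.
weights.assign([0.0, 0.0])
status = ckpt.restore(tf.train.latest_checkpoint(ckpt_dir))
status.assert_consumed()  # every saved value was matched to an object

restored = weights.numpy().tolist()
print(restored)
```

Note that `tf.train.latest_checkpoint` takes the directory and returns a checkpoint prefix (here ending in `ckpt-1`), which is what `restore` expects.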


tf.keras Model saving with TF2.0

I see great answers for saving models using TF1.x. I want to provide a couple more pointers on saving tensorflow.keras models, which is a little complicated because there are many ways to save a model.

Here I am providing an example of saving a tensorflow.keras model to the model_path folder under the current directory. This works well with the most recent TensorFlow (TF2.0). I will update this description if there is any change in the near future.

Saving and loading entire model

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

#import data
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# create a model
def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  # compile the model
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])
  return model

# Create a basic model instance
model=create_model()

model.fit(x_train, y_train, epochs=1)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save entire model to a HDF5 file
model.save('./model_path/my_model.h5')

# Recreate the exact same model, including weights and optimizer.
new_model = keras.models.load_model('./model_path/my_model.h5')
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Saving and loading model Weights only

If you are interested in saving only the model weights and later loading them to restore the model, then

model.fit(x_train, y_train, epochs=5)
loss, acc = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

# Save the weights
model.save_weights('./checkpoints/my_checkpoint')

# Restore the weights
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')

loss,acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))

Saving and restoring using keras checkpoint callback

import os

# include the epoch in the file name. (uses `str.format`)
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path, verbose=1, save_weights_only=True,
    # Save weights, every 5-epochs.
    # (`period` is deprecated in later releases in favor of `save_freq`.)
    period=5)

model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
# x_train, y_train, x_test, y_test are the MNIST arrays loaded above
model.fit(x_train, y_train,
          epochs = 50, callbacks = [cp_callback],
          validation_data = (x_test, y_test),
          verbose=0)

latest = tf.train.latest_checkpoint(checkpoint_dir)

new_model = create_model()
new_model.load_weights(latest)
loss, acc = new_model.evaluate(x_test, y_test)
print("Restored model, accuracy: {:5.2f}%".format(100*acc))
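
The `{epoch:04d}` placeholder in `checkpoint_path` is ordinary `str.format` syntax, filled in with the epoch number for each saved file. A small sketch in plain Python (no TensorFlow needed) of how the names expand:

```python
import os

checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"

# Epoch numbers are zero-padded to four digits, so files sort correctly by name.
names = [checkpoint_path.format(epoch=e) for e in (0, 5, 50)]
print(names)
# ['training_2/cp-0000.ckpt', 'training_2/cp-0005.ckpt', 'training_2/cp-0050.ckpt']

# The directory handed to tf.train.latest_checkpoint is the template's parent.
checkpoint_dir = os.path.dirname(checkpoint_path)
print(checkpoint_dir)  # training_2
```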

Saving model with custom metrics

import tensorflow as tf
from tensorflow import keras
mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0

# Custom Loss1 (for example) 
@tf.function() 
def customLoss1(yTrue,yPred):
  return tf.reduce_mean(yTrue-yPred) 

# Custom Loss2 (for example) 
@tf.function() 
def customLoss2(yTrue, yPred):
  return tf.reduce_mean(tf.square(tf.subtract(yTrue,yPred))) 

def create_model():
  model = tf.keras.models.Sequential([
    tf.keras.layers.Flatten(input_shape=(28, 28)),
    tf.keras.layers.Dense(512, activation=tf.nn.relu),  
    tf.keras.layers.Dropout(0.2),
    tf.keras.layers.Dense(10, activation=tf.nn.softmax)
    ])
  model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy', customLoss1, customLoss2])
  return model

# Create a basic model instance
model=create_model()

# Fit and evaluate model 
model.fit(x_train, y_train, epochs=1)
loss, acc,loss1, loss2 = model.evaluate(x_test, y_test,verbose=1)
print("Original model, accuracy: {:5.2f}%".format(100*acc))

model.save("./model.h5")

new_model=tf.keras.models.load_model("./model.h5",custom_objects={'customLoss1':customLoss1,'customLoss2':customLoss2})

Saving keras model with custom ops

When a model uses custom ops, as in the following case (tf.tile), we need to put the op in a function and wrap it with a Lambda layer. Otherwise, the model cannot be saved.

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras import Model

def my_fun(a):
  out = tf.tile(a, (1, tf.shape(a)[0]))
  return out

a = Input(shape=(10,))
#out = tf.tile(a, (1, tf.shape(a)[0]))
out = Lambda(lambda x : my_fun(x))(a)
model = Model(a, out)

x = np.zeros((50,10), dtype=np.float32)
print(model(x).numpy())

model.save('my_model.h5')

#load the model
new_model=tf.keras.models.load_model("my_model.h5")

I think I have covered a few of the many ways of saving a tf.keras model. However, there are many other ways. Please comment below if your use case is not covered above. Thanks!


You can also check out the examples in TensorFlow/skflow, which offers save and restore methods that can help you easily manage your models. It also has parameters that let you control how frequently you want to back up your model.


All the answers here are great, but I want to add two things.

First, to elaborate on @user7505159's answer, the "./" can be important to add to the beginning of the file name that you are restoring.

For example, you can save a graph with no "./" in the file name like so:

# Some graph defined up here with specific names

saver = tf.train.Saver()
save_file = 'model.ckpt'

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_file)

But in order to restore the graph, you may need to prepend a "./" to the file_name:

# Same graph defined up here

saver = tf.train.Saver()
save_file = './' + 'model.ckpt' # String addition used for emphasis

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, save_file)

You will not always need the "./", but it can cause problems depending on your environment and version of TensorFlow.
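
One way to sidestep the issue is to always pass a path with an explicit directory component. The helper below is a hypothetical convenience (`with_explicit_dir` is not part of TensorFlow), sketched in plain Python and assuming POSIX-style paths:

```python
import os


def with_explicit_dir(path):
    """Prepend './' when `path` has no directory component (hypothetical helper)."""
    if os.path.dirname(path):
        return path  # already has a directory part, leave it alone
    return "./" + path


print(with_explicit_dir("model.ckpt"))        # ./model.ckpt
print(with_explicit_dir("saved/model.ckpt"))  # saved/model.ckpt
```

The result can then be passed to saver.save and saver.restore, so both calls see the same unambiguous path.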

I also want to mention that calling sess.run(tf.global_variables_initializer()) before restoring the session can be important.

If you are receiving an error regarding uninitialized variables when trying to restore a saved session, make sure you include sess.run(tf.global_variables_initializer()) before the saver.restore(sess, save_file) line. It can save you a headache.


You can also take this easier way.

Step 1: initialize all your variables

W1 = tf.Variable(tf.truncated_normal([6, 6, 1, K], stddev=0.1), name="W1")
B1 = tf.Variable(tf.constant(0.1, tf.float32, [K]), name="B1")

Similarly, W2, B2, W3, .....

Step 2: create a Saver and use it to save the session

model_saver = tf.train.Saver()

# Train the model and save it in the end
model_saver.save(session, "saved_models/CNN_New.ckpt")

Step 3: restore the model

with tf.Session(graph=graph_cnn) as session:
    model_saver.restore(session, "saved_models/CNN_New.ckpt")
    print("Model restored.") 
    print('Initialized')

Step 4: check your variable

W1 = session.run(W1)
print(W1)

When running in a different Python instance, use

with tf.Session() as sess:
    # Rebuild the graph from the .meta file written by model_saver.save
    saver = tf.train.import_meta_graph('saved_models/CNN_New.ckpt.meta')

    # Restore latest checkpoint. Restoring also initializes the variables,
    # so do not run the initializer afterwards or it will overwrite them.
    saver.restore(sess, tf.train.latest_checkpoint('saved_models/'))

    # Get default graph (supply your custom graph if you have one)
    graph = tf.get_default_graph()

    # It will give tensor object
    W1 = graph.get_tensor_by_name('W1:0')

    # To get the value (numpy array)
    W1_value = sess.run(W1)

I'm on Version:

tensorflow (1.13.1)
tensorflow-gpu (1.13.1)

Simple way is

Save:

model.save("model.h5")

Restore:

model = tf.keras.models.load_model("model.h5")

If it is an internally saved model, you just specify a restorer for all variables as

restorer = tf.train.Saver(tf.all_variables())

and use it to restore variables in a current session:

restorer.restore(self._sess, model_file)

For an external model, you need to specify the mapping from its variable names to your variable names. You can view the model's variable names using the command

python /path/to/tensorflow/tensorflow/python/tools/inspect_checkpoint.py --file_name=/path/to/pretrained_model/model.ckpt

The inspect_checkpoint.py script can be found in './tensorflow/python/tools' folder of the Tensorflow source.

To specify the mapping, you can use my Tensorflow-Worklab, which contains a set of classes and scripts to train and retrain different models. It includes an example of retraining ResNet models, located here


As described in issue 6255:

use './model_name.ckpt':

saver.restore(sess, './my_model_final.ckpt')

instead of

saver.restore(sess, 'my_model_final.ckpt')

In (and after) TensorFlow version 0.11.0RC1, you can save and restore your model directly by calling tf.train.export_meta_graph and tf.train.import_meta_graph according to https://www.tensorflow.org/programmers_guide/meta_graph.

Save the model

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# `save` method will call `export_meta_graph` implicitly.
# you will get saved graph files:my-model.meta

Restore the model

sess = tf.Session()
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)

I am improving my answer to add more details for saving and restoring models.

In (and after) TensorFlow version 0.11:

Save the model:

import tensorflow as tf

#Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1= tf.Variable(2.0,name="bias")
feed_dict ={w1:4,w2:8}

#Define a test operation that we will restore
w3 = tf.add(w1,w2)
w4 = tf.multiply(w3,b1,name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

#Create a saver object which will save all the variables
saver = tf.train.Saver()

#Run the operation by feeding input
print(sess.run(w4, feed_dict))
#Prints 24, which is (w1+w2)*b1 = (4+8)*2

#Now, save the graph
saver.save(sess, 'my_test_model',global_step=1000)

Restore the model:

import tensorflow as tf

sess=tf.Session()    
#First let's load meta graph and restore weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess,tf.train.latest_checkpoint('./'))


# Access saved Variables directly
print(sess.run('bias:0'))
# This will print 2, which is the value of bias that we saved


# Now, let's access and create placeholders variables and
# create feed-dict to feed new data

graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict ={w1:13.0,w2:17.0}

#Now, access the op that you want to run. 
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))
#This will print 60, which is calculated as (13+17)*2

This and some more advanced use-cases have been explained very well here.

A quick complete tutorial to save and restore Tensorflow models


For tensorflow-2.0

it's very simple.

import tensorflow as tf

SAVE

model.save("model_name")

RESTORE

model = tf.keras.models.load_model('model_name')

If you use tf.train.MonitoredTrainingSession as the default session, you don't need to add extra code to save or restore. Just pass a checkpoint directory name to MonitoredTrainingSession's constructor, and it will use session hooks to handle both.


Here's my simple solution for the two basic cases differing on whether you want to load the graph from file or build it during runtime.

This answer holds for Tensorflow 0.12+ (including 1.0).

Rebuilding the graph in code

Saving

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading

graph = ... # build the graph
saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    # now you can use the graph, continue training or whatever

Loading also the graph from a file

When using this technique, make sure all your layers/variables have explicitly set unique names. Otherwise Tensorflow will make the names unique itself and they'll be thus different from the names stored in the file. It's not a problem in the previous technique, because the names are "mangled" the same way in both loading and saving.

Saving

graph = ... # build the graph

for op in [ ... ]:  # operators you want to use after restoring the model
    tf.add_to_collection('ops_to_restore', op)

saver = tf.train.Saver()  # create the saver after the graph
with ... as sess:  # your session object
    saver.save(sess, 'my-model')

Loading

with ... as sess:  # your session object
    saver = tf.train.import_meta_graph('my-model.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    ops = tf.get_collection('ops_to_restore')  # here are your operators in the same order in which you saved them to the collection

Here is a simple example using Tensorflow 2.0 SavedModel format (which is the recommended format, according to the docs) for a simple MNIST dataset classifier, using Keras functional API without too much fancy going on:

# Imports
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Flatten
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt

# Load data
mnist = tf.keras.datasets.mnist # 28 x 28
(x_train,y_train), (x_test, y_test) = mnist.load_data()

# Normalize pixels [0,255] -> [0,1]
x_train = tf.keras.utils.normalize(x_train,axis=1)
x_test = tf.keras.utils.normalize(x_test,axis=1)

# Create model
input = Input(shape=(28,28), dtype='float64', name='graph_input')
x = Flatten()(input)
x = Dense(128, activation='relu')(x)
x = Dense(128, activation='relu')(x)
output = Dense(10, activation='softmax', name='graph_output', dtype='float64')(x)
model = Model(inputs=input, outputs=output)

model.compile(optimizer='adam',
             loss='sparse_categorical_crossentropy',
             metrics=['accuracy'])

# Train
model.fit(x_train, y_train, epochs=3)

# Save model in SavedModel format (Tensorflow 2.0)
export_path = 'model'
tf.saved_model.save(model, export_path)

# ... possibly another python program 

# Reload model
loaded_model = tf.keras.models.load_model(export_path) 

# Get image sample for testing
index = 0
img = x_test[index] # I normalized the image on a previous step

# Predict using the signature definition (Tensorflow 2.0)
predict = loaded_model.signatures["serving_default"]
prediction = predict(tf.constant(img))

# Show results
print(np.argmax(prediction['graph_output']))  # prints the class number
plt.imshow(x_test[index], cmap=plt.cm.binary)  # prints the image

What is serving_default?

It's the name of the signature def of the tag you selected (in this case, the default serve tag was selected). Also, here is an explanation of how to find the tags and signatures of a model using saved_model_cli.

Disclaimers

This is just a basic example if you just want to get it up and running, but is by no means a complete answer - maybe I can update it in the future. I just wanted to give a simple example using the SavedModel in TF 2.0 because I haven't seen one, even this simple, anywhere.

@Tom's answer is a SavedModel example, but it will not work on Tensorflow 2.0, because unfortunately there are some breaking changes.

@Vishnuvardhan Janapati's answer says TF 2.0, but it's not for SavedModel format.


Following @Vishnuvardhan Janapati 's answer, here is another way to save and reload model with custom layer/metric/loss under TensorFlow 2.0.0

import tensorflow as tf
from tensorflow.keras.layers import Layer
from tensorflow.keras.utils import get_custom_objects

# custom loss (for example)
def custom_loss(y_true,y_pred):
  return tf.reduce_mean(y_true - y_pred)
get_custom_objects().update({'custom_loss': custom_loss})

# custom layer (for example)
class CustomLayer(Layer):
  def __init__(self, **kwargs):  # add your own arguments here
      super().__init__(**kwargs)
      ...
  # define custom layer and all necessary custom operations inside custom layer

get_custom_objects().update({'CustomLayer': CustomLayer})

This way, once you have executed this code and saved your model with tf.keras.models.save_model, model.save, or the ModelCheckpoint callback, you can reload your model without passing the custom objects explicitly, as simply as

new_model = tf.keras.models.load_model("./model.h5")

For TensorFlow version < 0.11.0RC1:

The checkpoints that are saved contain values for the Variables in your model, not the model/graph itself, which means that the graph should be the same when you restore the checkpoint.

Here's an example for a linear regression where there's a training loop that saves variable checkpoints and an evaluation section that will restore variables saved in a prior run and compute predictions. Of course, you can also restore variables and continue training if you'd like.

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)

w = tf.Variable(tf.zeros([1, 1], dtype=tf.float32))
b = tf.Variable(tf.ones([1, 1], dtype=tf.float32))
y_hat = tf.add(b, tf.matmul(x, w))

# ...more setup for optimization and what not...

saver = tf.train.Saver()  # defaults to saving all variables - in this case w and b

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    if FLAGS.train:
        for i in range(FLAGS.training_steps):
            ...  # training loop
            if (i + 1) % FLAGS.checkpoint_steps == 0:
                saver.save(sess, FLAGS.checkpoint_dir + 'model.ckpt',
                           global_step=i+1)
    else:
        # Here's where you're restoring the variables w and b.
        # Note that the graph is exactly as it was when the variables were
        # saved in a prior training run.
        ckpt = tf.train.get_checkpoint_state(FLAGS.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            ...  # no checkpoint found

        # Now you can run the model to get predictions
        batch_x = ...  # load some data
        predictions = sess.run(y_hat, feed_dict={x: batch_x})

Here are the docs for Variables, which cover saving and restoring. And here are the docs for the Saver.


My environment: Python 3.6, Tensorflow 1.3.0

Though there have been many solutions, most of them are based on tf.train.Saver. When we load a .ckpt saved by Saver, we have to either redefine the TensorFlow network or use some weird, hard-to-remember names, e.g. 'placehold_0:0', 'dense/Adam/Weight:0'. Here I recommend using tf.saved_model. One simple example is given below; you can learn more from Serving a TensorFlow Model:

Save the model:

import tensorflow as tf

# define the tensorflow network and do some trains
x = tf.placeholder("float", name="x")
w = tf.Variable(2.0, name="w")
b = tf.Variable(0.0, name="bias")

h = tf.multiply(x, w)
y = tf.add(h, b, name="y")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# save the model
export_path =  './savedmodel'
builder = tf.saved_model.builder.SavedModelBuilder(export_path)

tensor_info_x = tf.saved_model.utils.build_tensor_info(x)
tensor_info_y = tf.saved_model.utils.build_tensor_info(y)

prediction_signature = (
  tf.saved_model.signature_def_utils.build_signature_def(
      inputs={'x_input': tensor_info_x},
      outputs={'y_output': tensor_info_y},
      method_name=tf.saved_model.signature_constants.PREDICT_METHOD_NAME))

builder.add_meta_graph_and_variables(
  sess, [tf.saved_model.tag_constants.SERVING],
  signature_def_map={
      tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          prediction_signature 
  },
  )
builder.save()

Load the model:

import tensorflow as tf
sess=tf.Session() 
signature_key = tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
input_key = 'x_input'
output_key = 'y_output'

export_path =  './savedmodel'
meta_graph_def = tf.saved_model.loader.load(
           sess,
          [tf.saved_model.tag_constants.SERVING],
          export_path)
signature = meta_graph_def.signature_def

x_tensor_name = signature[signature_key].inputs[input_key].name
y_tensor_name = signature[signature_key].outputs[output_key].name

x = sess.graph.get_tensor_by_name(x_tensor_name)
y = sess.graph.get_tensor_by_name(y_tensor_name)

y_out = sess.run(y, {x: 3.0})

In TensorFlow 2.0, the process of saving and loading a model is a lot easier, thanks to the integration of the Keras API, a high-level API for TensorFlow.

To save a model (check the documentation for reference: https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/save_model):

tf.keras.models.save_model(model, filepath, save_format=save_format)  # model is the Keras model object

To load a model:

https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/models/load_model

model = tf.keras.models.load_model(filepath)

As Yaroslav said, you can hack restoring from a graph_def and checkpoint by importing the graph, manually creating variables, and then using a Saver.

I implemented this for my personal use, so I thought I'd share the code here.

Link: https://gist.github.com/nikitakit/6ef3b72be67b86cb7868

(This is, of course, a hack, and there is no guarantee that models saved this way will remain readable in future versions of TensorFlow.)


In most cases, saving and restoring from disk using a tf.train.Saver is your best option:

... # build your model
saver = tf.train.Saver()

with tf.Session() as sess:
    ... # train the model
    saver.save(sess, "/tmp/my_great_model")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

You can also save/restore the graph structure itself (see the MetaGraph documentation for details). By default, the Saver saves the graph structure into a .meta file. You can call import_meta_graph() to restore it. It restores the graph structure and returns a Saver that you can use to restore the model's state:

saver = tf.train.import_meta_graph("/tmp/my_great_model.meta")

with tf.Session() as sess:
    saver.restore(sess, "/tmp/my_great_model")
    ... # use the model

However, there are cases where you need something much faster. For example, if you implement early stopping, you want to save checkpoints every time the model improves during training (as measured on the validation set), then if there is no progress for some time, you want to roll back to the best model. If you save the model to disk every time it improves, it will tremendously slow down training. The trick is to save the variable states to memory, then just restore them later:

... # build your model

# get a handle on the graph nodes we need to save/restore the model
graph = tf.get_default_graph()
gvars = graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
assign_ops = [graph.get_operation_by_name(v.op.name + "/Assign") for v in gvars]
init_values = [assign_op.inputs[1] for assign_op in assign_ops]

with tf.Session() as sess:
    ... # train the model

    # when needed, save the model state to memory
    gvars_state = sess.run(gvars)

    # when needed, restore the model state
    feed_dict = {init_value: val
                 for init_value, val in zip(init_values, gvars_state)}
    sess.run(assign_ops, feed_dict=feed_dict)

A quick explanation: when you create a variable X, TensorFlow automatically creates an assignment operation X/Assign to set the variable's initial value. Instead of creating placeholders and extra assignment ops (which would just make the graph messy), we just use these existing assignment ops. The first input of each assignment op is a reference to the variable it is supposed to initialize, and the second input (assign_op.inputs[1]) is the initial value. So in order to set any value we want (instead of the initial value), we need to use a feed_dict and replace the initial value. Yes, TensorFlow lets you feed a value for any op, not just for placeholders, so this works fine.
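
The graph-mode trick above relies on TF 1.x assign ops. Under TensorFlow 2.x eager execution, the same save-to-memory idea can be sketched more directly by copying variable values into NumPy arrays and assigning them back. This is a different technique from the one described above, shown only as an illustration; it assumes TF 2.x, and the two variables are stand-ins for a real model's variables.

```python
import tensorflow as tf

# Two stand-in variables (assumption); in a real model use model.variables.
variables = [tf.Variable([1.0, 2.0]), tf.Variable(3.0)]

# Save the state to memory: copy each variable's current value.
snapshot = [v.numpy().copy() for v in variables]

# Simulate further training that changes the variables.
for v in variables:
    v.assign(v * 0.0)

# Roll back: assign the saved values to the variables.
for v, value in zip(variables, snapshot):
    v.assign(value)

restored = [v.numpy().tolist() for v in variables]
print(restored)
```

For early stopping, the snapshot would be refreshed each time the validation metric improves and assigned back once training stops making progress.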


Use tf.train.Saver to save a model. Remember that you need to specify the var_list if you want to reduce the model size. The var_list can be tf.trainable_variables() or tf.global_variables().


There are two parts to the model, the model definition, saved by Supervisor as graph.pbtxt in the model directory and the numerical values of tensors, saved into checkpoint files like model.ckpt-1003418.

The model definition can be restored using tf.import_graph_def, and the weights are restored using Saver.

However, Saver uses a special collection holding the list of variables that's attached to the model Graph, and this collection is not initialized using import_graph_def, so you can't use the two together at the moment (it's on our roadmap to fix). For now, you have to use the approach of Ryan Sepassi: manually construct a graph with identical node names, and use Saver to load the weights into it.

(Alternatively you could hack it by using import_graph_def, creating variables manually, and using tf.add_to_collection(tf.GraphKeys.VARIABLES, variable) for each variable, then using Saver)


You can save the variables in the network using

saver = tf.train.Saver() 
saver.save(sess, 'path of save/fileName.ckpt')

To restore the network for reuse later or in another script, use:

saver = tf.train.Saver()
saver.restore(sess, tf.train.latest_checkpoint('path of save/'))
sess.run(...)

Important points:

  1. The graph that sess runs must have the same structure in the first and later runs (coherent structure).
  2. tf.train.latest_checkpoint needs the path of the folder of the saved files, not an individual file path.

For tensorflow 2.0, it is as simple as

# Save the model
model.save('path_to_my_model.h5')

To restore:

new_model = tensorflow.keras.models.load_model('path_to_my_model.h5')
