A.I, Data and Software Engineering

Save, restore, visualise Graph with TensorFlow v2.0 & KERAS

TensorFlow 2.0 is coming really soon, so this post quickly walks through two useful features, saving and loading a pre-trained model, using the v2 syntax. To make things more intuitive, we also visualise the graph of the neural network model.

Benefits of saving a model

Quick answer: it saves time, makes sharing easy, and speeds up deployment.

A SavedModel contains a complete TensorFlow program, including weights and computation. It does not require the original model-building code to run, which makes it useful for sharing or deploying on different platforms, e.g. TFLite (mobile / IoT), TensorFlow.js (browsers), TensorFlow Serving (servers)…

With TensorFlow and Keras, we can easily save and restore models, custom models, and sessions. The basic steps are:

  • Create a model
  • Train the model
  • Save the model
  • Share and restore to use.

To demonstrate, we will quickly create a sequential neural network using Keras and the Fashion-MNIST dataset. You can also try the CIFAR dataset, as in this article.

Create a model with Keras and the Fashion-MNIST dataset

Import libraries and enable TensorFlow 2.0

from __future__ import absolute_import, division, print_function, unicode_literals
try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass
import tensorflow as tf
from matplotlib import pyplot as plt
import numpy as np

Output: TensorFlow 2.x selected.

Import Fashion-MNIST

We load the data as a training set and a validation set.

from tensorflow import keras
(x_train, y_train), (x_val, y_val) = keras.datasets.fashion_mnist.load_data()

From the training and validation sets, we build tf.data.Dataset objects using two helper functions:

def preprocess(x, y):
  x = tf.cast(x, tf.float32) / 255.0
  y = tf.cast(y, tf.int64)
  return x, y
def create_dataset(xs, ys, n_classes=10):
  ys = tf.one_hot(ys, depth=n_classes)
  return tf.data.Dataset.from_tensor_slices((xs, ys)) \
    .map(preprocess) \
    .shuffle(len(ys)) \
    .batch(128)
train_dataset = create_dataset(x_train, y_train)
val_dataset = create_dataset(x_val, y_val)
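
To make the two helpers concrete, here is a minimal pure-Python sketch of what they do to a single sample: the label becomes a one-hot vector and each pixel is scaled from the 0..255 range to 0..1. The names `one_hot` and `scale_pixels` are illustrative stand-ins written for this sketch, not TensorFlow APIs.

```python
def one_hot(label, n_classes=10):
    """Return a list with 1 at index `label` and 0 elsewhere."""
    return [1 if i == label else 0 for i in range(n_classes)]


def scale_pixels(pixels):
    """Map 8-bit pixel values (0..255) to floats in [0.0, 1.0]."""
    return [p / 255.0 for p in pixels]


encoded_label = one_hot(3)
scaled_row = scale_pixels([0, 51, 255])

print(encoded_label)  # [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
print(scaled_row)     # [0.0, 0.2, 1.0]
```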

Create and build a Keras sequential model

# we stack 4 dense layers
model = keras.Sequential([
    keras.layers.Reshape(target_shape=(28 * 28,), input_shape=(28, 28)),
    keras.layers.Dense(units=256, activation='relu'),
    keras.layers.Dense(units=192, activation='relu'),
    keras.layers.Dense(units=128, activation='relu'),
    keras.layers.Dense(units=10, activation='softmax')
])
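# Compile the model. The last layer already applies softmax, so the loss
# receives probabilities rather than logits (from_logits=False).
model.compile(optimizer='adam',
              loss=tf.losses.CategoricalCrossentropy(from_logits=False),
              metrics=['accuracy'])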

To visualise the model later, we define the Keras TensorBoard callback, where ‘logdir’ is the directory in which the graph info is stored.

# Define the Keras TensorBoard callback.
logdir="logs/fit/"
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

Train the model

history = model.fit(
    train_dataset.repeat(),
    epochs=10,
    steps_per_epoch=500,
    validation_data=val_dataset.repeat(),
    validation_steps=2,
    callbacks=[tensorboard_callback]
)
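
Because both datasets are repeated, `steps_per_epoch` and `validation_steps` decide how much data each epoch actually sees. Here is a quick arithmetic sketch, assuming the standard Fashion-MNIST split of 60,000 training and 10,000 validation images (sizes not stated in this article) and the batch size of 128 used above:

```python
import math

TRAIN_SAMPLES = 60_000   # assumed size of the Fashion-MNIST training split
VAL_SAMPLES = 10_000     # assumed size of the Fashion-MNIST test split
BATCH_SIZE = 128
STEPS_PER_EPOCH = 500
VALIDATION_STEPS = 2

# Batches needed for one full pass over the training data (the last one is partial).
batches_per_pass = math.ceil(TRAIN_SAMPLES / BATCH_SIZE)

# With .repeat(), 500 steps cover slightly more than one pass.
passes_per_epoch = STEPS_PER_EPOCH / batches_per_pass

# Only two validation batches are evaluated after each epoch.
val_samples_per_epoch = VALIDATION_STEPS * BATCH_SIZE
val_fraction = val_samples_per_epoch / VAL_SAMPLES

print(batches_per_pass)            # 469
print(round(passes_per_epoch, 2))  # 1.07
print(val_samples_per_epoch)       # 256
print(f"{val_fraction:.1%}")       # 2.6%
```

Evaluating on only 256 validation images per epoch is one likely reason val_accuracy jumps between epochs. Raising `validation_steps` should give a steadier estimate.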

Output:

Epoch 9/10
500/500 [==============================] - 7s 14ms/step - loss: 1.5904 - accuracy: 0.8704 - val_loss: 1.6009 - val_accuracy: 0.8672
Epoch 10/10
500/500 [==============================] - 8s 16ms/step - loss: 1.5895 - accuracy: 0.8713 - val_loss: 1.6156 - val_accuracy: 0.8477
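
The loss in this log stays near 1.59 even though accuracy is around 87%. One plausible explanation is that softmax is applied twice: once by the final `softmax` layer and once more inside a cross-entropy loss configured with `from_logits=True`. The pure-Python sketch below reimplements softmax and cross-entropy by hand, purely for illustration. It shows that in that situation even a perfect one-hot prediction cannot push the loss below roughly 1.46.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    peak = max(values)
    exps = [math.exp(v - peak) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(one_hot_target, probabilities):
    """Categorical cross-entropy for a single sample."""
    return -sum(math.log(p) for t, p in zip(one_hot_target, probabilities) if t)


target = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]
perfect_probs = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# Treating probabilities as logits applies softmax a second time.
double_softmax = softmax(perfect_probs)
loss_floor = cross_entropy(target, double_softmax)

print(f"{double_softmax[3]:.3f}")  # 0.232
print(f"{loss_floor:.2f}")         # 1.46
```

The logged training loss of about 1.59 sits just above that floor, which fits this reading. Passing probabilities with `from_logits=False`, or dropping the `softmax` activation and keeping `from_logits=True`, would let the loss approach zero.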

Before saving the model, let’s look at its current description using model.summary():

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
reshape (Reshape)            (None, 784)               0
_________________________________________________________________
dense (Dense)                (None, 256)               200960
_________________________________________________________________
dense_1 (Dense)              (None, 192)               49344
_________________________________________________________________
dense_2 (Dense)              (None, 128)               24704
_________________________________________________________________
dense_3 (Dense)              (None, 10)                1290
=================================================================
Total params: 276,298
Trainable params: 276,298
Non-trainable params: 0
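
The parameter counts in this summary can be checked by hand: a Dense layer with `n_in` inputs and `n_out` units has `n_in * n_out` weights plus `n_out` biases. Here is a short sketch (the helper `dense_params` is just an illustrative name):

```python
def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


# Layer widths from the model: flattened 28x28 input, then the four Dense layers.
widths = [28 * 28, 256, 192, 128, 10]

per_layer = [dense_params(a, b) for a, b in zip(widths, widths[1:])]
total = sum(per_layer)

print(per_layer)  # [200960, 49344, 24704, 1290]
print(total)      # 276298
```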

Save the model

A note on save_format: it is either ‘tf’ or ‘h5’, indicating whether the model is saved as a TensorFlow SavedModel or as HDF5. The default is ‘h5’ in TensorFlow 1.*, but ‘tf’ in TensorFlow 2.0.

model.save(filepath='./pretrain/', overwrite=True, save_format='tf')

We save the model in the ‘pretrain’ directory. Now, let’s check the saved model using ‘saved_model_cli’:

!saved_model_cli show --dir ./pretrain/ --tag_set serve --signature_def serving_default

Output:

The given SavedModel SignatureDef contains the following input(s):
  inputs['reshape_input'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 28, 28)
      name: serving_default_reshape_input:0
The given SavedModel SignatureDef contains the following output(s):
  outputs['dense_3'] tensor_info:
      dtype: DT_FLOAT
      shape: (-1, 10)
      name: StatefulPartitionedCall:0
Method name is: tensorflow/serving/predict

What is in the ‘pretrain’ folder:

We have two folders, “assets” and “variables”, and one file, “saved_model.pb”.

./pretrain:
total 132
drwxr-xr-x 2 root root   4096 Oct  8 08:52 assets
-rw-r--r-- 1 root root 118433 Oct  8 08:52 saved_model.pb
drwxr-xr-x 2 root root   4096 Oct  8 08:52 variables
./pretrain/assets:
./pretrain/variables:
total 3256
-rw-r--r-- 1 root root 3321771 Oct  8 08:52 variables.data-00000-of-00001
-rw-r--r-- 1 root root    2148 Oct  8 08:52 variables.index
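
The size of the variables file gives a rough hint about what was saved. Two assumptions go into the check below, and the listing confirms neither: the weights are stored as 4-byte float32 values, and the Adam optimizer state (two extra slot values per weight) was saved alongside them. Under those assumptions the numbers line up closely:

```python
PARAMS = 276_298          # total parameters from model.summary()
BYTES_PER_FLOAT32 = 4
FILE_SIZE = 3_321_771     # size of variables.data-00000-of-00001 in the listing

weights_only = PARAMS * BYTES_PER_FLOAT32
# Adam keeps two moving averages per weight, so weights + state is about 3x.
weights_plus_adam = 3 * weights_only

print(weights_only)                        # 1105192
print(weights_plus_adam)                   # 3315576
print(round(FILE_SIZE / weights_only, 2))  # 3.01
print(FILE_SIZE - weights_plus_adam)       # 6195
```

The few kilobytes left over would then be bookkeeping such as step counters and file-format overhead, though that part is a guess.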

Restore the saved model

Loading a model from the ‘pretrain’ directory is as simple as saving the model.

loaded_model = keras.models.load_model('./pretrain/')  # Done loading
# Check the structure of the loaded model if you want.
loaded_model.summary()

Get the predicted result from the restored model:

loaded_model.predict(val_dataset)

Output:

array([[7.41734179e-18, 4.21552374e-23, 5.25638975e-27, ...,
        1.00000000e+00, 1.66129058e-23, 5.22419891e-22],
       [1.03594854e-13, 5.41053513e-09, 1.00000000e+00, ...,
        8.35297819e-18, 1.34424147e-11, 7.73176926e-15],
       [1.76525968e-16, 1.02498872e-14, 1.03041346e-16, ...,
        6.16336324e-15, 5.33568415e-12, 1.33871903e-16],
       ...,
       [1.15577698e-20, 6.69796064e-28, 6.52123876e-31, ...,
        1.00000000e+00, 4.48179845e-27, 9.88130227e-25],
       [1.12627258e-19, 9.63445984e-23, 2.99525816e-23, ...,
        3.01945410e-22, 1.00000000e+00, 5.47376927e-28],
       [6.66963896e-19, 1.56428669e-22, 6.76676873e-26, ...,
        6.58221911e-24, 1.00000000e+00, 7.36716641e-28]], dtype=float32)
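
Each row of this array is a probability distribution over the ten classes, so the predicted class is the index of the largest value. The sketch below uses a made-up probability row, because the real rows above are truncated. The class names are those commonly listed for Fashion-MNIST and are an assumption not stated in this article.

```python
CLASS_NAMES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]


def predicted_class(probabilities):
    """Return (index, name) of the most probable class for one prediction row."""
    index = max(range(len(probabilities)), key=probabilities.__getitem__)
    return index, CLASS_NAMES[index]


# Hypothetical row for illustration only.
example_row = [0.01, 0.0, 0.02, 0.0, 0.01, 0.0, 0.03, 0.9, 0.01, 0.02]
result = predicted_class(example_row)
print(result)  # (7, 'Sneaker')
```

With the real output, `np.argmax(predictions, axis=1)` does the same for every row at once. Keep in mind that `create_dataset` shuffles the data, so the rows returned by `predict` will not line up with the original order of `y_val`.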

Visualise the model with Tensorboard

You can use TensorBoard to visualize your TensorFlow graph, plot quantitative metrics about the execution of your graph, and show additional data like images that pass through it. I find it so powerful and really enjoyable to play with.

If you added the Keras callback as mentioned in the previous section, you will be able to use TensorBoard embedded in a Jupyter notebook.

# Load the TensorBoard notebook extension.
%load_ext tensorboard
%tensorboard --logdir logs

Figure: Visualise the Keras model with TensorBoard

And if you want to see the accuracy and loss graph, you can switch to the “SCALARS” tab.

Figure: Visualise Train/Loss of Keras model with TensorBoard

Conclusion

In this article, we have demonstrated how easy it is to save, load, and visualise a model with Keras and TensorBoard. We did not focus on perfecting the model, as it was for demo purposes. There are also several changes in TensorFlow v2 that we have not mentioned here, but we may cover some of the most exciting parts in future posts.
