20 June 2017

Display sample predictions during training using Keras callbacks

I have seen many sample Keras scripts that let model.fit finish after a couple of epochs so that they can compute and display a sample of the predictions the model generates. model.fit is then called again in a loop to continue training. Unfortunately, this technique confuses TensorBoard as it tries to trace how the training progresses and how indicators like loss and accuracy change. This is because restarting the fitting makes TensorBoard think that new data is coming in at previous time points, leading to some very confused graphs.

Fortunately, LambdaCallback makes it very easy to set up a new Keras callback that displays sample predictions without having to stop the training. We can have a function called after each epoch to print out the info; it can even check the epoch number so that it prints less frequently than at the end of every epoch.
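As a rough sketch of that idea (not the exact code from this post), the hook below prints predictions only every few epochs. The names make_sample_printer, make_sample_callback, sample_inputs and every are hypothetical, chosen for illustration; the only Keras pieces assumed are LambdaCallback with its on_epoch_end argument and model.predict.

```python
def make_sample_printer(predict, sample_inputs, every=5):
    """Build an on_epoch_end hook that prints predictions every `every` epochs.

    `predict` is any callable mapping inputs to predictions,
    for example model.predict (hypothetical helper, for illustration).
    """
    def on_epoch_end(epoch, logs=None):
        # Keras passes a zero-based epoch index, so shift it by one.
        if (epoch + 1) % every != 0:
            return None
        predictions = predict(sample_inputs)
        print("Sample predictions after epoch %d:" % (epoch + 1))
        for x, y in zip(sample_inputs, predictions):
            print("  %s -> %s" % (x, y))
        # Keras ignores the return value; it is handy when trying the hook by hand.
        return predictions
    return on_epoch_end


def make_sample_callback(model, sample_inputs, every=5):
    # Imported here so the hook above can be tried without Keras installed.
    from keras.callbacks import LambdaCallback
    return LambdaCallback(
        on_epoch_end=make_sample_printer(model.predict, sample_inputs, every))
```

The resulting object can then be passed in the callbacks list of the fit call, so training never has to stop just to show predictions.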

Using this callback together with the callback that logs data for TensorBoard is illustrated below. In this code we use an iterator to get the training data, but the same pattern can be used with pre-loaded data, too.