Elod P Csirmaz’s Blog: Display sample predictions during training using Keras callbacks

20 June 2017

Display sample predictions during training using Keras callbacks

I have seen many sample Keras scripts that let a call to model.fit finish after a couple of epochs just to create an opportunity to compute and display a sample of the model's predictions, and then call model.fit again in a loop to continue training. Unfortunately, this technique confuses TensorBoard, which tries to trace how training progresses and how indicators like loss and accuracy change. This is because restarting the fitting makes TensorBoard think that new data is arriving for earlier time points, which leads to some very confused graphs.

Fortunately, LambdaCallback makes it very easy to set up a new Keras callback that displays sample predictions without stopping the training. We can have a function called at the end of each epoch to print out the information; the function can even check the epoch number so that it prints only every few epochs rather than after every one.
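A minimal sketch of such a callback, assuming TensorFlow 2's bundled Keras (from tensorflow import keras) and a hypothetical toy regression task (learning y = 2x + 1). The names show_samples, sample_inputs and EVERY_N_EPOCHS are illustrative, not taken from the original script. Keras passes a 0-based epoch index to on_epoch_end, so the check below adds 1 before testing divisibility.

```python
import numpy as np
from tensorflow import keras

# Hypothetical toy task: learn y = 2x + 1.
rng = np.random.default_rng(0)
x_train = rng.uniform(-1.0, 1.0, size=(256, 1)).astype("float32")
y_train = 2.0 * x_train + 1.0

# Fixed inputs whose predictions we watch as training progresses.
sample_inputs = np.array([[-1.0], [0.0], [1.0]], dtype="float32")

model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

EVERY_N_EPOCHS = 5  # hypothetical setting: show samples on every 5th epoch only
printed_epochs = []  # records the (1-based) epochs at which samples were shown


def show_samples(epoch, logs):
    # Skip all epochs except every Nth one (epoch is 0-based).
    if (epoch + 1) % EVERY_N_EPOCHS != 0:
        return
    preds = model.predict(sample_inputs, verbose=0)
    printed_epochs.append(epoch + 1)
    print(f"Epoch {epoch + 1}, loss {float(logs['loss']):.4f}")
    for x, p in zip(sample_inputs[:, 0], preds[:, 0]):
        print(f"  input {float(x):+.1f} -> prediction {float(p):.3f}")


sample_cb = keras.callbacks.LambdaCallback(on_epoch_end=show_samples)

# Training runs uninterrupted; the callback fires inside this single call.
history = model.fit(x_train, y_train, epochs=20, batch_size=32,
                    callbacks=[sample_cb], verbose=0)
```

Predicting on a handful of fixed inputs every few epochs is cheap compared with the training itself.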

Using this callback together with the callback that logs data for TensorBoard is illustrated below. In this code we use an iterator to get the training data, but the same pattern works with pre-loaded data, too.
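A minimal sketch of combining the two callbacks, assuming TensorFlow 2's bundled Keras. In 2017-era Keras the generator would have been handed to model.fit_generator; in TensorFlow 2, model.fit accepts a Python generator directly (fit_generator is deprecated there and no longer exists in Keras 3). The body of data_iterator, the toy task, show_samples and the step and epoch counts are all hypothetical.

```python
import tempfile

import numpy as np
from tensorflow import keras


def data_iterator(batch_size=32, seed=0):
    """Hypothetical endless generator of (inputs, targets) batches for y = 2x + 1."""
    rng = np.random.default_rng(seed)
    while True:
        x = rng.uniform(-1.0, 1.0, size=(batch_size, 1)).astype("float32")
        yield x, 2.0 * x + 1.0


# Fixed inputs whose predictions are printed during training.
sample_inputs = np.array([[-1.0], [0.0], [1.0]], dtype="float32")

model = keras.Sequential([keras.Input(shape=(1,)), keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

shown_at = []  # 1-based epochs at which samples were displayed


def show_samples(epoch, logs):
    # Keras passes a 0-based epoch index; only act on every 5th epoch.
    if (epoch + 1) % 5 != 0:
        return
    preds = model.predict(sample_inputs, verbose=0)
    shown_at.append(epoch + 1)
    print(f"Epoch {epoch + 1}, loss {float(logs['loss']):.4f}")
    for x, p in zip(sample_inputs[:, 0], preds[:, 0]):
        print(f"  input {float(x):+.1f} -> prediction {float(p):.3f}")


log_dir = tempfile.mkdtemp()  # TensorBoard event files are written here
callbacks = [
    keras.callbacks.TensorBoard(log_dir=log_dir),
    keras.callbacks.LambdaCallback(on_epoch_end=show_samples),
]

# One uninterrupted fit call: epochs run 0..19 without restarts,
# so TensorBoard records a single continuous timeline.
history = model.fit(
    data_iterator(),
    steps_per_epoch=8,
    epochs=20,
    callbacks=callbacks,
    verbose=0,
)
```

Pointing TensorBoard at log_dir (tensorboard --logdir <that directory>) should then show one unbroken loss curve, while the sample predictions appear in the console every five epochs.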

4 comments:

  1. Hi Elod, is there a way to get the predictions without calling model.predict again, since the model has already computed these values while the neural network ran on the image?

    1. Hi, thanks for your comment. It is indeed true. However, I see a few issues: during training the network processes a whole batch, so you'd need to extract the one output you're interested in; and it may be tricky to get hold of the output when using the convenience functions of Keras for training. It may be easier with just TensorFlow. I'd suggest calling predict because it shouldn't be too expensive, and it's usually not done too frequently anyway.

  2. I don't understand what to put for data_iterator.

    1. Hi, in this example I used fit_generator, which uses a function (generator=data_iterator) that keeps returning the training data. You don't need to do this; instead of fit_generator, you can use model.fit, and then you can pass it your data as-is. Please see https://keras.io/models/model/#fit and https://keras.io/models/model/#fit_generator . This latter link explains how to write a suitable generator function. Hope this helps.

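For readers with the same question about data_iterator: a sketch of a generator of the kind fit_generator expects, assuming the training data are NumPy arrays x and y that fit in memory. The signature and body of this data_iterator are hypothetical. Following the Keras documentation for fit_generator, it loops over the data indefinitely and yields (inputs, targets) tuples; the length of an epoch is set separately with steps_per_epoch.

```python
import numpy as np


def data_iterator(x, y, batch_size=32, seed=0):
    """Hypothetical generator yielding shuffled (x_batch, y_batch) pairs forever."""
    rng = np.random.default_rng(seed)
    n = len(x)
    while True:
        order = rng.permutation(n)  # reshuffle at the start of every pass
        for start in range(0, n, batch_size):
            idx = order[start:start + batch_size]
            yield x[idx], y[idx]


# Tiny illustrative dataset: 10 samples, batches of 4.
x = np.arange(10, dtype="float32").reshape(-1, 1)
y = 2.0 * x + 1.0
gen = data_iterator(x, y, batch_size=4)
batches = [next(gen) for _ in range(3)]  # one full pass: batch sizes 4, 4, 2
```

With a real model this would be used as model.fit(data_iterator(x, y), steps_per_epoch=math.ceil(len(x) / 32), epochs=...) in TensorFlow 2, or with the same arguments passed to model.fit_generator in older Keras. If the arrays are already in memory, passing them straight to model.fit is simpler still.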