Introduction

In the last blog, we used callbacks with the Optimizer class, which we called optimizer callbacks.

Callbacks are pieces of code that you inject into another piece of code at some predefined point, instead of adding them directly to the source code. In other words, the source code is expected to call back into your code at that predefined point if the callback is passed as an argument.

Injecting code is much easier than altering the source code, since you won't have to copy-paste code from the library; you can just write your own callbacks.
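
As a plain-Python sketch (not fastai code), the idea looks like this: the library's training function exposes a predefined point and calls back whatever function we pass in.

def train(n_epochs, on_epoch_end=None):
    # the library's "source code": it calls back our function at a predefined point
    for epoch in range(n_epochs):
        ...  # do the actual work for one epoch
        if on_epoch_end is not None: on_epoch_end(epoch)

# we inject our own behavior without touching `train` itself
train(3, on_epoch_end=lambda epoch: print(f"finished epoch {epoch}"))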

Fastai is different from other libraries in that it allows callbacks to read, modify, and control every piece of information and every process available in the training loop.

How to create a Callback

A Callback, unlike an optimizer callback, is for the training loop and has access to many events, which can be looked up as attributes of the event variable in case you forget them. The full list can be found here.
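
For instance, the event variable collects all the event names as attributes, so we get tab-completion and typo-proofing; each attribute is just the matching string (a quick check, assuming fastai is installed):

from fastai.callback.core import event

# every training-loop event name is an attribute of `event`,
# and its value is simply the corresponding string
assert event.before_fit == 'before_fit'
assert event.after_batch == 'after_batch'
assert event.after_loss == 'after_loss'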

To create a callback, we subclass Callback. For example, when we covered RNNs, we used ModelResetter, which we said calls reset at the start of training and validation for each epoch (and, as the code shows, at the end of fitting). In code, it becomes:

class ModelResetter(Callback):
    "`Callback` that resets the model at each validation/training step"
    def before_train(self):    self.model.reset()
    def before_validate(self): self.model.reset()
    def after_fit(self):       self.model.reset()
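
To actually use a callback like this, we pass it to the Learner (or to a single call to fit) through the cbs argument. A minimal sketch, assuming dls (a DataLoaders) and an RNN model with a reset method are already built:

from fastai.text.all import *

# sketch only: `dls` and `model` are assumed to exist already
learn = Learner(dls, model, loss_func=CrossEntropyLossFlat(), cbs=ModelResetter())

# callbacks can also be passed to a single call to fit ...
learn.fit(1, cbs=ModelResetter())
# ... or added after the Learner has been created
learn.add_cb(ModelResetter())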

Another example is RNNRegularizer, which handled activation and temporal activation regularizations:

class RNNRegularizer(Callback):
    "Add AR and TAR regularization"
    order,run_valid = RNNCallback.order+1,False
    def __init__(self, alpha=0., beta=0.): store_attr()
    def after_loss(self):
        if not self.training: return
        if self.alpha: self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()
        if self.beta:
            h = self.rnn.raw_out
            if len(h)>1: self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()

Callbacks also have access to the attributes of the Learner. Many of them have shortcuts when reading, like self.model instead of self.learn.model, but we have to use the long form when writing: self.learn.loss_grad = instead of self.loss_grad =.
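
For example, a hypothetical callback that scales the loss used for the backward pass reads through the shortcut but writes through self.learn:

from fastai.callback.core import Callback

# hypothetical sketch: scale the gradient-carrying loss by a constant factor
class ScaleLossCallback(Callback):
    def __init__(self, factor=0.5): self.factor = factor
    def after_loss(self):
        if not self.training: return
        # reading can use the shortcut (`self.loss_grad` is looked up on the Learner),
        # but writing has to go through `self.learn`
        self.learn.loss_grad = self.loss_grad * self.factor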

Next, we can also use callbacks to control the training process by skipping a batch or an epoch, or by stopping training altogether. These callbacks raise different exceptions, like CancelBatchException, which skips the rest of the batch and goes to after_cancel_batch before after_batch.
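
As a hypothetical illustration, a callback that skips any batch whose inputs contain NaN or infinite values could raise CancelBatchException in before_batch:

import torch
from fastai.callback.core import Callback, CancelBatchException

# hypothetical sketch: skip batches whose inputs contain NaN or inf values;
# raising CancelBatchException skips the rest of the batch and jumps to
# `after_cancel_batch`, then `after_batch`
class SkipBadBatchCallback(Callback):
    def before_batch(self):
        if not all(torch.isfinite(t).all() for t in self.xb):
            raise CancelBatchException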

Finally, we can adjust the order in which callbacks run through order, a class attribute that you specify outside of the method definitions. For example, RNNRegularizer will always be called after RNNCallback since its order is the order of RNNCallback plus 1. Another example is TerminateOnNaNCallback, which has a fixed order of -9:

class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    order=-9
    def after_batch(self):
        "Test if `last_loss` is NaN and interrupts training."
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException

Instead of order, which is stricter, we can also use run_before and run_after, like:

class TerminateOnNaNCallback(Callback):
    run_before=Recorder
    def after_batch(self):
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException

This ensures that TerminateOnNaNCallback is executed before the Recorder callback.

Conclusion

Instead of modifying the source code, callbacks allow us to inject code at predefined points in the source code without losing flexibility (provided the library makes callbacks as flexible as fastai does). We can use callbacks to read and write information in the training loop and to control its flow.