How to use fastai's training loop Callbacks
Calling back, but in code.
In the last blog, we used callbacks for the Optimizer class, which we called optimizer callbacks.
Callbacks are pieces of code that you inject into another piece of code at some predefined point, instead of directly editing the source code. In other words, when the callback is passed as an argument, the source code is expected to call it back at that predefined point.
Injecting code is much easier than altering the source code, since you don't have to copy and paste code out of the library; you can just write your own callbacks.
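To make the idea concrete before we get to fastai, here is a minimal plain-Python sketch; the function and callback names are made up for illustration:

def train(n_steps, on_step_end=None):
    for step in range(n_steps):
        ...                        # the library's own training code would go here
        if on_step_end is not None:
            on_step_end(step)      # the predefined point where our code gets called back

train(3, on_step_end=lambda step: print(f"finished step {step}"))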
Fastai stands out from other libraries because its callbacks can read, modify, and control every piece of information and every step available in the training loop.
A Callback, unlike an optimizer callback, is for the training loop and has access to many events; in case you forget them, their names are available as attributes of the event variable. The full list can be found here.
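For instance, assuming a standard fastai install, event simply maps each event name to itself, so you get tab completion instead of having to memorize the names:

from fastai.callback.core import event

print(event.before_train)   # 'before_train'
print(event.after_loss)     # 'after_loss'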
To create a callback, we subclass Callback. For example, when we covered RNNs, we used ModelResetter, which we said calls the model's reset method at the start of training and validation in each epoch (and once more after fitting). In code, it becomes:
class ModelResetter(Callback):
    "`Callback` that resets the model at each validation/training step"
    def before_train(self):    self.model.reset()
    def before_validate(self): self.model.reset()
    def after_fit(self):       self.model.reset()
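As a quick aside on how a callback actually gets used, here is a hypothetical toy setup (random data, a linear model, and a made-up PrintEpoch callback, none of which are part of fastai) showing the two usual ways to attach one to a Learner:

from fastai.basics import *

class PrintEpoch(Callback):
    "Print a message at the start of every epoch"
    def before_epoch(self): print(f"starting epoch {self.epoch}")

train_ds = [(torch.randn(2), torch.randn(1)) for _ in range(16)]
valid_ds = [(torch.randn(2), torch.randn(1)) for _ in range(4)]
dls = DataLoaders.from_dsets(train_ds, valid_ds, bs=4, device='cpu')

# attach the callback when building the Learner ...
learn = Learner(dls, nn.Linear(2, 1), loss_func=nn.MSELoss(), cbs=PrintEpoch())
learn.fit(1)
# ... or pass it for a single call only: learn.fit(1, cbs=PrintEpoch())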
Another example is RNNRegularizer, which handles activation regularization (AR) and temporal activation regularization (TAR):
class RNNRegularizer(Callback):
    "Add AR and TAR regularization"
    order,run_valid = RNNCallback.order+1,False
    def __init__(self, alpha=0., beta=0.): store_attr()
    def after_loss(self):
        if not self.training: return
        if self.alpha: self.learn.loss_grad += self.alpha * self.rnn.out.float().pow(2).mean()
        if self.beta:
            h = self.rnn.raw_out
            if len(h)>1: self.learn.loss_grad += self.beta * (h[:,1:] - h[:,:-1]).float().pow(2).mean()
Callbacks also have access to the different attributes of Learner. Some of these have shortcuts for reading, like self.model instead of self.learn.model, but we have to use the long form when writing: self.learn.loss_grad = ... rather than self.loss_grad = ....
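As a small illustration of that rule (LossScaler is a made-up callback, not part of fastai), this reads the loss through the shortcut but writes it back through self.learn:

from fastai.callback.core import Callback

class LossScaler(Callback):
    "Hypothetical example: scale the loss used for the backward pass"
    def __init__(self, factor=0.5): self.factor = factor
    def after_loss(self):
        if not self.training: return
        scaled = self.loss_grad * self.factor   # reading: shortcut for self.learn.loss_grad
        self.learn.loss_grad = scaled           # writing: the long form is required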
Next, we can also use callbacks to control the flow of training: skipping a batch, skipping an epoch, or stopping training altogether. These callbacks raise special exceptions such as CancelBatchException, which skips the rest of the batch and jumps to after_cancel_batch before after_batch.
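For example, a made-up callback that skips every other batch could look like this (assuming the usual fastai imports):

from fastai.callback.core import Callback, CancelBatchException

class SkipOddBatches(Callback):
    "Hypothetical example: skip the rest of every odd-numbered batch"
    def before_batch(self):
        if self.iter % 2 == 1: raise CancelBatchException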
Finally, we can adjust the order in which callbacks run through the order attribute, which is set at the class level, outside of any method definition. For example, RNNRegularizer will always be called after RNNCallback since its order is RNNCallback.order + 1. Another example is TerminateOnNaNCallback, which has a fixed order of -9:
class TerminateOnNaNCallback(Callback):
    "A `Callback` that terminates training if loss is NaN."
    order=-9
    def after_batch(self):
        "Test if `last_loss` is NaN and interrupts training."
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException
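Since order is just a class attribute, you can check it directly (assuming these callbacks live in the usual fastai modules):

from fastai.callback.rnn import RNNCallback, RNNRegularizer
from fastai.callback.tracker import TerminateOnNaNCallback

print(RNNCallback.order, RNNRegularizer.order)   # RNNRegularizer.order == RNNCallback.order + 1
print(TerminateOnNaNCallback.order)              # -9, so it runs before most other callbacks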
Instead of order, which is stricter, we can also use run_before and run_after, like:
class TerminateOnNaNCallback(Callback):
    run_before=Recorder
    def after_batch(self):
        if torch.isinf(self.loss) or torch.isnan(self.loss): raise CancelFitException
This, in our case, ensures that TerminateOnNaNCallback is executed before the Recorder callback.
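The opposite constraint works the same way. For instance, a made-up logging callback could declare run_after = Recorder so that, per the ordering rule above, the smoothed loss that Recorder maintains has already been updated when it reads it:

from fastai.basics import *

class LogSmoothLoss(Callback):
    "Hypothetical example: log the running loss that Recorder maintains"
    run_after, run_valid = Recorder, False
    def after_batch(self): print(f"smoothed loss: {self.smooth_loss:.4f}")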
Instead of modifying the source code, callbacks let us inject code at predefined points of the source code without losing flexibility (provided the library makes callbacks as flexible as fastai does). We can use callbacks to read and write information in the training loop and to control its flow.