Callbacks API

A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, or before or after a single batch).

You can use callbacks to:

  • Write TensorBoard logs after every batch of training to monitor your metrics
  • Periodically save your model to disk
  • Do early stopping
  • Inspect the internal states and statistics of a model during training
  • ...and more

Usage of callbacks

You can pass a list of callbacks (as the keyword argument callbacks) to the .fit() method of a model:

my_callbacks = [
    k4t.callbacks.EarlyStopping(patience=2),
    k4t.callbacks.ModelCheckpoint('best_model.pt', monitor='val_acc'),
]
model.fit(x, y, epochs=10, callbacks=my_callbacks)

The relevant methods of the callbacks will then be called at each stage of the training.
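
The order in which the hooks fire can be illustrated with a minimal sketch. The run_training function and the RecordingCallback class below are hypothetical stand-ins written for this illustration, not part of k4t, and the trainer argument is only a placeholder. The real .fit() loop may differ in detail.

```python
class RecordingCallback:
    """Stand-in callback that records the name of every hook it receives."""

    def __init__(self):
        self.events = []

    def on_train_begin(self, trainer):
        self.events.append("train_begin")

    def on_train_end(self, trainer):
        self.events.append("train_end")

    def on_epoch_begin(self, trainer):
        self.events.append("epoch_begin")

    def on_epoch_end(self, trainer):
        self.events.append("epoch_end")

    def on_batch_begin(self, trainer):
        self.events.append("batch_begin")

    def on_batch_end(self, trainer):
        self.events.append("batch_end")


def run_training(callbacks, epochs, batches_per_epoch):
    """Hypothetical training loop showing where each hook would be called."""
    trainer = None  # placeholder for whatever object the real loop passes in
    for cb in callbacks:
        cb.on_train_begin(trainer)
    for _ in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(trainer)
        for _ in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_begin(trainer)
            # ... forward pass, loss, backward pass and optimizer step go here ...
            for cb in callbacks:
                cb.on_batch_end(trainer)
        for cb in callbacks:
            cb.on_epoch_end(trainer)
    for cb in callbacks:
        cb.on_train_end(trainer)


recorder = RecordingCallback()
run_training([recorder], epochs=1, batches_per_epoch=2)
print(recorder.events)
```

With the library itself, model.fit(..., callbacks=my_callbacks) would play the role of run_training here.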

Available callbacks

  • ModelCheckpoint
  • EarlyStopping
  • LRScheduler
  • LambdaCallback
  • CSVLogger

Custom callbacks

To write a custom callback, subclass k4t.callbacks.Callback and override only the methods for the stages you need:

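class CustomCallback(k4t.callbacks.Callback):
    def __init__(self):
        super().__init__()  # keep any setup done by the base class

    def on_epoch_begin(self, trainer):
        pass  # runs at the start of each epoch

    def on_epoch_end(self, trainer):
        pass  # runs at the end of each epoch

    def on_batch_begin(self, trainer):
        pass  # runs before each batch

    def on_batch_end(self, trainer):
        pass  # runs after each batch

    def on_train_begin(self, trainer):
        pass  # runs when training starts

    def on_train_end(self, trainer):
        pass  # runs when training ends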
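
As a concrete illustration of the subclassing pattern, the sketch below times each epoch. EpochTimer is a hypothetical name, and the Callback class defined here is only a stand-in so that the example runs without k4t installed. In real use you would subclass k4t.callbacks.Callback instead. The example does not read anything from trainer, since its attributes are not described here.

```python
import time


class Callback:
    """Stand-in for k4t.callbacks.Callback (assumed to provide no-op hooks)."""

    def on_epoch_begin(self, trainer):
        pass

    def on_epoch_end(self, trainer):
        pass


class EpochTimer(Callback):
    """Hypothetical callback that records the wall-clock duration of each epoch."""

    def __init__(self):
        super().__init__()
        self.durations = []  # seconds taken by each completed epoch
        self._start = None

    def on_epoch_begin(self, trainer):
        self._start = time.perf_counter()

    def on_epoch_end(self, trainer):
        self.durations.append(time.perf_counter() - self._start)


# Drive the hooks by hand; in real use the training loop would call them.
timer = EpochTimer()
timer.on_epoch_begin(trainer=None)
timer.on_epoch_end(trainer=None)
print(f"epochs timed: {len(timer.durations)}")
```

Only the two epoch hooks are overridden; the remaining hooks are left to the base class.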

For simple cases, you can instead pass functions for individual stages to LambdaCallback:

LambdaCallback(
    on_epoch_begin=None,
    on_epoch_end=None,
    on_batch_begin=None,
    on_batch_end=None,
    on_train_begin=None,
    on_train_end=None
)
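
The following sketch shows how such a callback might be used. The LambdaCallback class defined here is a minimal stand-in so the example is self-contained, not the k4t implementation, and it assumes that each function receives the same trainer argument as the corresponding method of a Callback subclass.

```python
def _noop(trainer):
    """Default hook that does nothing."""


class LambdaCallback:
    """Stand-in that stores one function per stage (assumed behaviour)."""

    def __init__(self, on_epoch_begin=None, on_epoch_end=None,
                 on_batch_begin=None, on_batch_end=None,
                 on_train_begin=None, on_train_end=None):
        # Any stage left as None falls back to a no-op.
        self.on_epoch_begin = on_epoch_begin or _noop
        self.on_epoch_end = on_epoch_end or _noop
        self.on_batch_begin = on_batch_begin or _noop
        self.on_batch_end = on_batch_end or _noop
        self.on_train_begin = on_train_begin or _noop
        self.on_train_end = on_train_end or _noop


epochs_seen = []

# Only the end-of-epoch stage is customized; every other stage does nothing.
epoch_counter = LambdaCallback(
    on_epoch_end=lambda trainer: epochs_seen.append(len(epochs_seen) + 1),
)

# Simulate two epochs by calling the hooks by hand.
for _ in range(2):
    epoch_counter.on_epoch_begin(None)
    epoch_counter.on_epoch_end(None)

print(epochs_seen)
```

This form suits short, stateless actions. A subclass is usually easier to read once a callback needs its own state or several cooperating hooks.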