Training API (Beta)
By default, the training pipeline of keras4torch
handles most common cases. In some situations, however, you may want to customize it. To this end, we provide a set of hooks.
Subclass `k4t.configs.TrainerLoopConfig`,
override one or more of its hook methods,
and then pass an instance to `Model.compile(...)` via the `loop_config` argument.
```python
class TrainerLoopConfig:
    def __init__(self):
        self._train = None

    @property
    def training(self) -> bool:
        # True during the training loop, False during validation
        return self._train

    def on_epoch_begin(self):
        # called at the start of each epoch
        pass

    def process_batch(self, batch):
        # split a batch from the data loader into inputs and target
        *x_batch, y_batch = batch
        return x_batch, y_batch

    def forward_call(self, model, x_batch):
        # run the forward pass
        return model(*x_batch)

    def prepare_for_optimizer_step(self, model):
        # called after backward() and before optimizer.step()
        pass

    def prepare_for_metrics_update(self, y_batch_pred, y_batch):
        # transform predictions and targets before the per-batch metrics update
        return y_batch_pred, y_batch

    def cache_for_epoch_metrics(self, y_batch_pred, y_batch):
        # move tensors to CPU for caching until epoch-level metrics are computed
        return y_batch_pred.cpu(), y_batch.cpu()
```