Models API
keras4torch.Model wraps a torch.nn.Module to integrate training and inference features.
Configs
.compile(optimizer, loss, metrics, ...)
Configure the model for training.
optimizer: String (name of optimizer) or optimizer instance.
loss: String (name of objective function), objective function, or loss instance.
metrics: List of metrics to be evaluated by the model during training. You can also pass a dict to specify an abbreviation for each metric.
epoch_metrics: List of non-linear metrics (e.g. ROC_AUC) that need to be evaluated at the end of each epoch.
device: Device of the model and its trainer. If None, 'cuda' is used when torch.cuda.is_available(), otherwise 'cpu'.
loop_config: Optional TrainerLoopConfig object to customize the training and validation loops. See Training API for details.
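The default for device can be sketched in plain PyTorch. This is only an illustration of the documented fallback rule; resolve_device is a hypothetical helper, not part of keras4torch.

```python
import torch


def resolve_device(device=None):
    """Hypothetical helper mirroring the documented default for `device`."""
    if device is None:
        # Documented fallback: prefer CUDA when available, otherwise CPU.
        return 'cuda' if torch.cuda.is_available() else 'cpu'
    # An explicit choice is returned unchanged.
    return device


print(resolve_device())       # 'cuda' or 'cpu', depending on the machine
print(resolve_device('cpu'))
```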
NumPy workflow
.fit(x, y, epochs, batch_size, ...)
Train the model for a fixed number of epochs (iterations on a dataset).
x (ndarray or torch.Tensor or Dataset): Input data
y (ndarray or torch.Tensor): Target data
epochs (int, default=10): Number of epochs to train the model
batch_size (int, default=32): Number of samples per gradient update
validation_batch_size (int, default=None): Number of samples per step of the validation loop; if None, batch_size is used
validation_split (float between 0 and 1): Fraction of the training data to be used as validation data
shuffle_val_split (bool, default=True): Whether to shuffle the data when validation_split is provided
validation_data (tuple of x and y, or Dataset): Data on which to evaluate the loss and any model metrics at the end of each epoch
callbacks (list of keras4torch.callbacks.Callback): List of callbacks to apply during training
verbose (int, default=1): 0, 1, or 2. Verbosity mode. 0 = silent, 1 = normal, 2 = brief
shuffle (bool, default=True): Whether to shuffle the training data before each epoch
sample_weight (list of floats): Optional weights for the training samples. If provided, WeightedRandomSampler is enabled
num_workers (int, default=0): Number of DataLoader workers. If -1, cpu_count() - 1 is used for multiprocessing
use_amp (bool, default=False): Whether to use automatic mixed precision
accum_grad_steps (int, default=1): Number of steps over which gradients are accumulated before each update of the model parameters
**dl_kwargs: Extra keyword arguments for DataLoader
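To make accum_grad_steps concrete, here is a minimal gradient accumulation loop in plain PyTorch. It is a sketch of the general technique, not the library's actual training loop. Dividing the loss by the number of accumulation steps is a common convention and is an assumption here.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

x = torch.randn(8, 4)
y = torch.randn(8, 1)
batch_size = 2
accum_grad_steps = 2  # update the parameters once every 2 batches

num_updates = 0
optimizer.zero_grad()
for step, start in enumerate(range(0, len(x), batch_size), start=1):
    xb = x[start:start + batch_size]
    yb = y[start:start + batch_size]
    # Assumption: scale the loss so the accumulated gradient is an average.
    loss = loss_fn(model(xb), yb) / accum_grad_steps
    loss.backward()  # gradients add up across backward calls
    if step % accum_grad_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
        num_updates += 1

print(num_updates)
```

With 8 samples, a batch size of 2, and 2 accumulation steps, the loop runs 4 batches and performs 2 parameter updates.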
.evaluate(x, y, batch_size, ...)
Return the loss value and metric values for the model in test mode.
x (ndarray or torch.Tensor or Dataset): Input data
y (ndarray or torch.Tensor): Target data
batch_size (int, default=32): Number of samples per batch
num_workers (int, default=0): Number of DataLoader workers. If -1, cpu_count() - 1 is used for multiprocessing
use_amp (bool, default=False): Whether to use automatic mixed precision
**dl_kwargs: Extra keyword arguments for DataLoader
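The num_workers=-1 rule can be sketched with the standard library alone. resolve_num_workers is a hypothetical helper. Using os.cpu_count() and clamping the result at zero are assumptions, since the text only says cpu_count() - 1.

```python
import os


def resolve_num_workers(num_workers=0):
    """Hypothetical helper: map -1 to cpu_count() - 1, keep other values."""
    if num_workers == -1:
        # os.cpu_count() may return None, so fall back to 1 core.
        # Clamp at 0 so a single-core machine loads data in the main process.
        return max((os.cpu_count() or 1) - 1, 0)
    return num_workers


print(resolve_num_workers(-1))  # depends on the machine
print(resolve_num_workers(4))
```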
.predict(x, batch_size, ...)
Generate output predictions for the input samples.
x (ndarray or torch.Tensor or Dataset): Input data
batch_size (int, default=32): Number of samples per batch
device (default=None): Device on which to run inference
output_numpy (bool, default=True): If True, the output is moved to CPU and converted to a NumPy array
activation (Callable or str, default=None): Extra activation applied to the output tensor
num_workers (int, default=0): Number of DataLoader workers. If -1, cpu_count() - 1 is used for multiprocessing
use_amp (bool, default=False): Whether to use automatic mixed precision
**dl_kwargs: Extra keyword arguments for DataLoader
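The following plain-PyTorch sketch illustrates how batch_size, activation, and output_numpy could interact. predict_sketch is a hypothetical stand-in for illustration, not the library's implementation, and it only handles tensor input and callable activations.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset


@torch.no_grad()
def predict_sketch(model, x, batch_size=32, activation=None, output_numpy=True):
    """Hypothetical batched inference loop (tensor input only)."""
    model.eval()
    loader = DataLoader(TensorDataset(x), batch_size=batch_size, shuffle=False)
    # Run the model batch by batch and stitch the outputs back together.
    out = torch.cat([model(xb) for (xb,) in loader], dim=0)
    if activation is not None:
        out = activation(out)  # extra activation on the raw output
    return out.cpu().numpy() if output_numpy else out


model = nn.Linear(4, 3)
x = torch.randn(10, 4)
probs = predict_sketch(
    model, x, batch_size=4,
    activation=lambda t: torch.softmax(t, dim=-1),
)
print(probs.shape)
```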
DataLoader workflow
.fit_dl(train_loader, val_loader, epochs, ...)
.evaluate_dl(data_loader, ...)
.predict_dl(data_loader, ...)
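The methods in this workflow take DataLoader objects rather than arrays. The sketch below builds a training and a validation loader with standard PyTorch utilities. The data, the 80/20 split, and the batch size are made up for illustration.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Illustrative random data: 100 samples with 8 features and binary labels.
x = torch.randn(100, 8)
y = torch.randint(0, 2, (100,))

train_ds = TensorDataset(x[:80], y[:80])
val_ds = TensorDataset(x[80:], y[80:])

# Shuffle only the training data; keep validation order fixed.
train_loader = DataLoader(train_ds, batch_size=32, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=32, shuffle=False)

print(len(train_loader), len(val_loader))  # number of batches per epoch
```

Loaders like these would then be passed as the train_loader and val_loader arguments of .fit_dl, or as data_loader to .evaluate_dl and .predict_dl.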
Saving & Serialization
.save_weights(filepath)
Equivalent to torch.save(model.state_dict(), filepath).
.load_weights(filepath)
Equivalent to model.load_state_dict(torch.load(filepath)).
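The two underlying PyTorch calls form a round trip, sketched below on a bare torch.nn.Module. The file name and the temporary directory are illustrative only.

```python
import os
import tempfile

import torch
from torch import nn

model = nn.Linear(4, 2)
restored = nn.Linear(4, 2)  # same architecture, different random weights

with tempfile.TemporaryDirectory() as tmp:
    filepath = os.path.join(tmp, 'weights.pt')
    # What save_weights is documented to do.
    torch.save(model.state_dict(), filepath)
    # What load_weights is documented to do.
    restored.load_state_dict(torch.load(filepath))

# After loading, every tensor in the two state dicts should match.
same = all(
    torch.equal(a, b)
    for a, b in zip(model.state_dict().values(), restored.state_dict().values())
)
print(same)
```

Only the weights are stored, so the module receiving them must already have a matching architecture.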
Utilities
.summary(depth, ...)
Print a string summary of the network.
depth (default=3): Level of detail of the summary
.count_params()
Count the total number of scalars composing the weights.
.trainable_params()
Return all trainable parameters of the model.
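In plain PyTorch terms, these two utilities plausibly correspond to the expressions below. This is an assumption about their behavior, not a copy of the library's code. The small network and the frozen first layer are illustrative.

```python
from torch import nn

model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

# Freeze the first layer so that total and trainable counts differ.
for p in model[0].parameters():
    p.requires_grad = False

# Total number of scalars across all weights (cf. count_params).
total = sum(p.numel() for p in model.parameters())

# Parameters that still receive gradients (cf. trainable_params).
trainable = [p for p in model.parameters() if p.requires_grad]
num_trainable = sum(p.numel() for p in trainable)

print(total, num_trainable)
```

Here the first layer holds 4 * 8 + 8 = 40 scalars and the second 8 * 2 + 2 = 18, so the total is 58 and 18 remain trainable.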