Experimental

This module provides experimental features; use them with caution.

class experimental.TorchPredictor(model, is_trained=False, optimizer=None, criterion=None, **compile_kwargs)

Bases: object

Wrapper of a torch prediction model \(\hat{f}\). It standardizes the interface of torch predictors and exposes generic fit(), predict() and copy() methods.

Parameters:
  • model (Any) – torch prediction model \(\hat{f}\).

  • is_trained (bool) – flag indicating whether the model is already trained. If True, calls to fit() are skipped.

  • optimizer (torch.optim.Optimizer) – torch optimizer. Defaults to torch.optim.Adam.

  • criterion (torch.nn.modules.loss) – criterion that measures the distance between predictions and targets. Defaults to torch.nn.MSELoss.

  • compile_kwargs โ€“ keyword arguments passed to the optimizer constructor when creating the optimizer in fit().
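The defaults described above can be sketched as follows. This is a hypothetical helper, not this module's implementation: it assumes that optimizer is an optimizer class (such as torch.optim.Adam) which is instantiated with the model parameters and compile_kwargs, as the description of compile_kwargs suggests. The name build_training_objects is illustrative only.

```python
import torch


def build_training_objects(model, optimizer=None, criterion=None, **compile_kwargs):
    """Hypothetical helper mirroring the documented defaults.

    Assumption: `optimizer` is an optimizer class (not an instance) that is
    instantiated with the model parameters and `compile_kwargs`.
    """
    optimizer_cls = optimizer if optimizer is not None else torch.optim.Adam
    loss_fn = criterion if criterion is not None else torch.nn.MSELoss()
    # compile_kwargs (e.g. lr=1e-3) are forwarded to the optimizer constructor
    opt = optimizer_cls(model.parameters(), **compile_kwargs)
    return opt, loss_fn


model = torch.nn.Linear(3, 1)
opt, loss_fn = build_training_objects(model, lr=1e-3)
print(type(opt).__name__, type(loss_fn).__name__)  # Adam MSELoss
print(opt.param_groups[0]["lr"])  # 0.001
```

Here lr=1e-3, passed through compile_kwargs, becomes the learning rate of the constructed optimizer.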

copy()

Returns a copy of the predictor.

Returns:

copy of the predictor.

Return type:

TorchPredictor
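The following sketch illustrates the behaviour one would expect from copy(): an independent predictor whose parameters can change without affecting the original. It assumes a deep copy made with Python's copy.deepcopy; SketchPredictor is a hypothetical stand-in, not the TorchPredictor class itself.

```python
import copy

import torch


class SketchPredictor:
    """Hypothetical stand-in for TorchPredictor, used only to illustrate copy()."""

    def __init__(self, model, is_trained=False):
        self.model = model
        self.is_trained = is_trained

    def copy(self):
        # Assumption: a deep copy, so the clone owns separate parameter tensors.
        return copy.deepcopy(self)


p = SketchPredictor(torch.nn.Linear(2, 1))
with torch.no_grad():
    p.model.weight.fill_(1.0)
q = p.copy()
with torch.no_grad():
    q.model.weight.fill_(0.0)  # modifies the copy only
print(p.model.weight.sum().item(), q.model.weight.sum().item())  # 2.0 0.0
```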

fit(X, y, **kwargs)

Fit the model to the training data.

Parameters:
  • X (torch.Tensor) – training features.

  • y (torch.Tensor) – training labels.

  • kwargs – keyword arguments that configure the training.
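A training procedure of the kind fit() performs can be sketched as a plain full-batch loop. Everything here is an assumption for illustration: the function fit_sketch, the epochs keyword (standing in for the training kwargs) and the use of full-batch updates are not taken from this module. The sketch does reproduce the documented behaviour that a pre-trained model is left untouched.

```python
import torch


def fit_sketch(model, X, y, is_trained=False, epochs=100, **compile_kwargs):
    """Hypothetical full-batch training loop; `epochs` is an assumed kwarg."""
    if is_trained:
        return model  # pre-trained models are returned untouched
    opt = torch.optim.Adam(model.parameters(), **compile_kwargs)
    loss_fn = torch.nn.MSELoss()
    model.train()
    for _ in range(epochs):
        opt.zero_grad()             # clear gradients from the previous step
        loss = loss_fn(model(X), y)
        loss.backward()             # backpropagate the loss
        opt.step()                  # update the parameters
    return model


torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])  # noiseless linear target
model = torch.nn.Linear(3, 1)
mse = torch.nn.MSELoss()
before = mse(model(X), y).item()
fit_sketch(model, X, y, epochs=200, lr=0.05)
after = mse(model(X), y).item()
print(after < before)  # True
```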

predict(X)

Compute predictions on new examples.

Parameters:

X (torch.Tensor) – features of the new examples.

Returns:

predictions \(\hat{f}(X)\) associated with the new examples X.

Return type:

torch.Tensor
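Predictions are commonly computed in evaluation mode and without gradient tracking. The sketch below assumes predict() behaves this way; predict_sketch is a hypothetical function, not this module's code.

```python
import torch


def predict_sketch(model, X):
    """Hypothetical predict(): inference without gradient tracking."""
    model.eval()  # switch layers such as dropout/batch-norm to inference mode
    with torch.no_grad():
        return model(X)


model = torch.nn.Linear(2, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[1.0, 2.0]]))
    model.bias.fill_(0.5)
X = torch.tensor([[1.0, 1.0], [0.0, 3.0]])
y_hat = predict_sketch(model, X)
print(y_hat.squeeze(1).tolist())  # [3.5, 6.5]
print(y_hat.requires_grad)  # False
```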