Experimental

This module provides experimental features. Use cautiously.

class experimental.TorchPredictor(model, is_trained=False, optimizer=<class 'torch.optim.adam.Adam'>, criterion=MSELoss(), **compile_kwargs)

Bases: object

Wrapper of a torch prediction model \(\hat{f}\). Standardizes the interface of torch predictors and exposes generic fit(), predict() and copy() methods.

Parameters:
  • model (Any) – torch prediction model \(\hat{f}\)

  • is_trained (bool) – boolean flag indicating whether the model is pre-trained. If True, calls to fit() are skipped.

  • optimizer (torch.optim.Optimizer) – torch optimizer class. Defaults to torch.optim.Adam.

  • criterion (torch.nn.modules.loss) – criterion that measures the distance between predictions and outputs. Defaults to torch.nn.MSELoss.

  • compile_kwargs – keyword arguments passed, if needed, to the call to model.compile() on the underlying model.
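The signature shows the default optimizer as the class torch.optim.Adam rather than an instance. This suggests the wrapper binds it to the model's parameters itself. The sketch below is not the library's code. It shows that instantiation and what the default nn.MSELoss criterion computes. The nn.Linear model and the lr value are illustrative assumptions.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(3, 1)  # illustrative stand-in for the wrapped model

# The default optimizer is a class; it must be instantiated with the
# parameters it will update (lr here is an assumed value).
optimizer_cls = torch.optim.Adam
optimizer = optimizer_cls(model.parameters(), lr=1e-3)

# nn.MSELoss() averages the squared errors over all elements.
criterion = nn.MSELoss()
pred = torch.tensor([[1.0], [2.0]])
target = torch.tensor([[0.0], [4.0]])
loss = criterion(pred, target)
print(loss.item())  # (1**2 + 2**2) / 2 = 2.5
```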

Note

The model constructor must accept both input_feat: int and output_feat: int as arguments, corresponding to the number of features (or channels) of each input and output, respectively.
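The following is a minimal model class that satisfies this constructor convention. It is a sketch: the class name MLP and the extra hidden argument are hypothetical and are not required by the wrapper.

```python
import torch
from torch import nn


class MLP(nn.Module):
    """Hypothetical regression network following the input_feat/output_feat convention."""

    def __init__(self, input_feat: int, output_feat: int, hidden: int = 32):
        super().__init__()
        # input_feat and output_feat fix the first and last layer widths;
        # hidden is an illustrative extra argument with a default value.
        self.net = nn.Sequential(
            nn.Linear(input_feat, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_feat),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


model = MLP(input_feat=3, output_feat=1)
out = model(torch.randn(8, 3))
print(out.shape)  # torch.Size([8, 1])
```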

copy()

Returns a copy of the predictor.

Returns:

copy of the predictor.

Return type:

TorchPredictor
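The documentation does not say whether copy() is deep or shallow. The sketch below assumes a deep copy, so that training the copy leaves the original untouched. SketchPredictor is a hypothetical stand-in, not the library class.

```python
import copy

import torch
from torch import nn


class SketchPredictor:
    """Hypothetical minimal wrapper used only to illustrate copy semantics."""

    def __init__(self, model: nn.Module, is_trained: bool = False):
        self.model = model
        self.is_trained = is_trained

    def copy(self) -> "SketchPredictor":
        # deepcopy duplicates the module and its parameter tensors,
        # so later updates to the copy do not affect the original.
        return copy.deepcopy(self)


original = SketchPredictor(nn.Linear(2, 1))
clone = original.copy()
with torch.no_grad():
    clone.model.weight.fill_(0.0)  # modify only the copy
print(torch.equal(original.model.weight, clone.model.weight))  # False
```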

fit(X, y, **kwargs)

Fit the model to the training data.

Parameters:
  • X (torch.Tensor) – train features.

  • y (torch.Tensor) – train labels.

  • kwargs – keyword arguments to configure the training.
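The following is a sketch of a full-batch training loop consistent with the parameters described above. The early return mirrors the documented is_trained behavior. The keyword arguments epochs and lr are hypothetical, because the accepted training kwargs are not listed. The stand-alone fit function is also illustrative, not the library method.

```python
import torch
from torch import nn


def fit(model, X, y, is_trained=False, optimizer_cls=torch.optim.Adam,
        criterion=None, epochs=200, lr=1e-2):
    """Hypothetical full-batch training loop; epochs and lr are assumed kwargs."""
    if is_trained:
        return model  # pre-trained models skip fitting entirely
    # Fall back to mean squared error, matching the documented default criterion.
    criterion = criterion if criterion is not None else nn.MSELoss()
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()            # clear gradients from the previous step
        loss = criterion(model(X), y)    # distance between predictions and labels
        loss.backward()
        optimizer.step()
    return model


torch.manual_seed(0)
X = torch.randn(64, 2)
y = X @ torch.tensor([[2.0], [-1.0]]) + 0.5  # synthetic linear targets
model = nn.Linear(2, 1)
before = nn.MSELoss()(model(X), y).item()
fit(model, X, y)
after = nn.MSELoss()(model(X), y).item()
```

Comparing before and after should show the training loss decreasing.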

predict(X)

Compute predictions on new examples.

Parameters:

X (torch.Tensor) – new examples’ features.

Returns:

predictions \(\hat{f}(X)\) associated with the new examples X.

Return type:

torch.Tensor
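A typical way to compute \(\hat{f}(X)\) at inference time is to switch the model to evaluation mode and disable gradient tracking. The sketch below assumes predict() does the same; the helper function and the small network are hypothetical.

```python
import torch
from torch import nn


def predict(model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Hypothetical inference helper returning f_hat(X) without tracking gradients."""
    model.eval()  # disables dropout and uses running batch-norm statistics
    with torch.no_grad():
        return model(X)


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
preds = predict(model, torch.randn(5, 4))
print(preds.shape, preds.requires_grad)  # torch.Size([5, 2]) False
```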