Experimental
This module provides experimental features. Use them with caution.
- class experimental.TorchPredictor(model, is_trained=False, optimizer=None, criterion=None, **compile_kwargs)
Bases: object
Wrapper of a torch prediction model \(\hat{f}\). It standardizes the interface of torch predictors and exposes generic fit(), predict() and copy() methods.
- Parameters:
model (Any) – torch prediction model \(\hat{f}\).
is_trained (bool) – flag indicating whether the model is already trained. If True, the call to fit() is skipped.
optimizer (torch.optim.Optimizer) – torch optimizer. Defaults to torch.optim.Adam.
criterion (torch.nn.modules.loss) – criterion that measures the distance between predictions and outputs. Defaults to torch.nn.MSELoss.
compile_kwargs – keyword arguments passed to the optimizer constructor when the optimizer is created in fit().
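The documented defaults can be illustrated with a minimal sketch. The helper `resolve_training_components` below is hypothetical and not part of this module; it assumes that `optimizer` is an optimizer class (so that `compile_kwargs` can be forwarded to its constructor) and that `criterion` is a loss instance such as `torch.nn.MSELoss()`.

```python
import torch
from torch import nn


def resolve_training_components(model, optimizer=None, criterion=None, **compile_kwargs):
    """Hypothetical helper mirroring the documented TorchPredictor defaults.

    Assumption: `optimizer` is an optimizer class (e.g. torch.optim.SGD), so
    `compile_kwargs` such as lr=... can be passed to its constructor.
    """
    # Fall back to Adam when no optimizer class is given.
    optimizer_cls = optimizer if optimizer is not None else torch.optim.Adam
    # Fall back to mean squared error when no criterion is given.
    loss_fn = criterion if criterion is not None else nn.MSELoss()
    # compile_kwargs are forwarded to the optimizer constructor.
    opt = optimizer_cls(model.parameters(), **compile_kwargs)
    return opt, loss_fn


model = nn.Linear(3, 1)
opt, loss_fn = resolve_training_components(model, lr=0.01)
print(type(opt).__name__, type(loss_fn).__name__)  # Adam MSELoss
print(opt.param_groups[0]["lr"])  # 0.01
```

Deferring the optimizer construction until the model parameters are available is what makes it possible to pass only a class plus keyword arguments at wrapper creation time.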
- copy()
Returns a copy of the predictor.
- Returns:
copy of the predictor.
- Return type:
TorchPredictor
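The documentation above does not say whether copy() is deep or shallow. Assuming it behaves like a deep copy of the wrapped model (an assumption, not confirmed here), the sketch below shows what that implies: the copy starts with the same weights but can be modified independently of the original.

```python
import copy

import torch
from torch import nn

# Hypothetical stand-in for the wrapped model; assumes copy() deep-copies it.
model = nn.Linear(2, 1)
model_copy = copy.deepcopy(model)

# The copy starts with the same weights as the original.
print(torch.equal(model.weight, model_copy.weight))  # True

# Updating the copy in place leaves the original model unchanged.
with torch.no_grad():
    model_copy.weight.add_(1.0)
print(torch.equal(model.weight, model_copy.weight))  # False
```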
- fit(X, y, **kwargs)
Fit the model to the training data.
- Parameters:
X (torch.Tensor) – training features.
y (torch.Tensor) – training labels.
kwargs – keyword arguments to configure the training.
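A fit step for a torch model typically runs a gradient-descent loop over the training data. The sketch below is a hypothetical full-batch loop, not this module's implementation: `fit_sketch` and its `epochs` argument are illustrative assumptions, since the accepted training kwargs are not documented here. It also mirrors the documented rule that a pre-trained model (`is_trained=True`) skips fitting.

```python
import torch
from torch import nn


def fit_sketch(model, X, y, optimizer, criterion, is_trained=False, epochs=100):
    """Hypothetical full-batch training loop; `epochs` is an assumed kwarg."""
    if is_trained:
        # Mirrors the documented behavior: pre-trained models skip fitting.
        return model
    model.train()
    for _ in range(epochs):
        optimizer.zero_grad()            # reset accumulated gradients
        loss = criterion(model(X), y)    # distance between predictions and labels
        loss.backward()                  # backpropagate
        optimizer.step()                 # update parameters
    return model


torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])  # noiseless linear target
model = nn.Linear(3, 1)
opt = torch.optim.Adam(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()

before = loss_fn(model(X), y).item()
fit_sketch(model, X, y, opt, loss_fn, epochs=200)
after = loss_fn(model(X), y).item()
print(after < before)  # True
```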
- predict(X)
Compute predictions on new examples.
- Parameters:
X (torch.Tensor) – features of the new examples.
- Returns:
predictions \(\hat{f}(X)\) associated with the new examples X.
- Return type:
torch.Tensor
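A common way to compute \(\hat{f}(X)\) with a torch model is to switch it to evaluation mode and disable gradient tracking. The documentation does not state that predict() does this; the hypothetical `predict_sketch` helper below only illustrates the usual pattern.

```python
import torch
from torch import nn


def predict_sketch(model, X):
    """Hypothetical inference helper: eval mode, no gradient tracking."""
    model.eval()              # disable training-only behavior (dropout, batch-norm updates)
    with torch.no_grad():     # avoid building an autograd graph
        return model(X)


model = nn.Sequential(nn.Linear(3, 8), nn.ReLU(), nn.Linear(8, 1))
X_new = torch.randn(5, 3)
y_hat = predict_sketch(model, X_new)
print(y_hat.shape, y_hat.requires_grad)  # torch.Size([5, 1]) False
```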