Experimental
This module provides experimental features. Use cautiously.
- class experimental.TorchPredictor(model, is_trained=False, optimizer=<class 'torch.optim.adam.Adam'>, criterion=MSELoss(), **compile_kwargs)
Bases: object
Wrapper around a torch prediction model \(\hat{f}\). It standardizes the interface of torch predictors and exposes generic fit(), predict() and copy() methods.
- Parameters:
model (Any) - torch prediction model \(\hat{f}\).
is_trained (bool) - flag indicating whether the model is pre-trained. If True, calls to fit() are skipped.
optimizer (torch.optim.Optimizer) - torch optimizer class (not an instance). Defaults to torch.optim.Adam.
criterion (torch.nn.modules.loss) - criterion that measures the distance between predictions and targets. Defaults to torch.nn.MSELoss().
compile_kwargs - keyword arguments passed, if needed, to model.compile() on the underlying model.
Note
The model constructor must accept both input_feat: int and output_feat: int as arguments, giving the number of features (or channels) of each input and output, respectively.
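The note above fixes the constructor signature of the wrapped model. A minimal sketch of a compatible model follows; the class name `TinyMLP` and its `hidden` width are hypothetical, and only the `input_feat`/`output_feat` parameter names come from this page. This page does not say whether `model` should be passed as the class or as an instance.

```python
import torch
from torch import nn


class TinyMLP(nn.Module):
    """Hypothetical two-layer network whose constructor follows the
    input_feat / output_feat convention required by TorchPredictor."""

    def __init__(self, input_feat: int, output_feat: int, hidden: int = 16):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_feat, hidden),
            nn.ReLU(),
            nn.Linear(hidden, output_feat),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (n_samples, input_feat) -> output: (n_samples, output_feat)
        return self.net(x)


model = TinyMLP(input_feat=3, output_feat=2)
out = model(torch.randn(5, 3))
print(out.shape)  # torch.Size([5, 2])
```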
- copy()
Returns a copy of the predictor.
- Returns:
copy of the predictor.
- Return type:
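This page does not state how deep the copy is. Assuming `copy()` behaves like `copy.deepcopy` on the wrapped model (an assumption, not confirmed here), the copy's parameters are independent of the original's, as this sketch shows with a plain `nn.Linear`.

```python
import copy

import torch
from torch import nn

original = nn.Linear(2, 1)
before = original.weight.detach().clone()

# Deep copy: the clone gets its own parameter tensors rather than shared references.
clone = copy.deepcopy(original)

with torch.no_grad():
    clone.weight.fill_(0.0)  # mutate only the copy

# The original's weights are untouched by changes to the copy.
print(torch.equal(original.weight, before))  # True
```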
- fit(X, y, **kwargs)
Fit model to the training data.
- Parameters:
X (torch.Tensor) - training features.
y (torch.Tensor) - training labels.
kwargs - keyword arguments that configure training.
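The constructor arguments suggest that fitting combines the `optimizer` class with the `criterion`. A minimal full-batch training loop under that assumption is sketched below; `fit_sketch`, `epochs` and `lr` are hypothetical names, since the keyword arguments accepted by `fit()` are not listed on this page.

```python
import torch
from torch import nn


def fit_sketch(model: nn.Module, X: torch.Tensor, y: torch.Tensor,
               optimizer_cls=torch.optim.Adam, criterion=nn.MSELoss(),
               epochs: int = 200, lr: float = 1e-2) -> list:
    """Hypothetical full-batch training loop; returns the loss of each epoch."""
    # The optimizer is given as a class (matching the Adam default)
    # and instantiated here on the model's parameters.
    optimizer = optimizer_cls(model.parameters(), lr=lr)
    model.train()
    losses = []
    for _ in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(X), y)  # distance between predictions and targets
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
X = torch.randn(64, 3)
y = X @ torch.tensor([[1.0], [-2.0], [0.5]])  # linear target, shape (64, 1)
model = nn.Linear(3, 1)
losses = fit_sketch(model, X, y)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")
```

Passing the optimizer as a class rather than an instance lets the wrapper bind it to the model's parameters itself, which is consistent with the `<class 'torch.optim.adam.Adam'>` default in the signature.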
- predict(X)
Compute predictions on new examples.
- Parameters:
X (torch.Tensor) - features of the new examples.
- Returns:
predictions \(\hat{f}(X)\) for the new examples X.
- Return type:
torch.Tensor
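Inference in PyTorch is usually run in evaluation mode and without gradient tracking. The sketch below shows that pattern, assuming (not confirmed here) that `predict()` does something similar; `predict_sketch` is a hypothetical name.

```python
import torch
from torch import nn


def predict_sketch(model: nn.Module, X: torch.Tensor) -> torch.Tensor:
    """Hypothetical inference helper returning f_hat(X)."""
    model.eval()  # disable dropout, use running batch-norm statistics
    with torch.no_grad():  # no autograd graph is needed for inference
        return model(X)


model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
preds = predict_sketch(model, torch.randn(10, 4))
print(preds.shape, preds.requires_grad)  # torch.Size([10, 2]) False
```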