Module bastionlab.torch.learner
Classes
RemoteLearner()
-
Represents a remote model on the server along with hyperparameters to train and test it.
The remote learner accepts the model to be trained together with a remote dataset (the remote_dataset argument).
Args:
    client: A BastionAI client used to access server resources.
    model: A PyTorch nn.Module, or a BastionAI gRPC protocol reference to a remote model.
    remote_dataset: A BastionAI remote dataloader.
    loss: The name of the loss function used to train the model. Supported values are "l2" and "cross_entropy".
    optimizer: The configuration of the optimizer used during training. Refer to the documentation of OptimizerConfig.
    device: Name of the device on which to train the model. The list of supported devices can be obtained with the get_available_devices endpoint of the BastionLabTorch object.
    max_grad_norm: The clipping threshold for gradients in DP-SGD.
    metric_eps_per_batch: The privacy budget allocated to disclosing the loss of each batch. It may be overridden by providing a global budget for loss disclosure over the whole training when calling the fit method.
    model_name: A name for the uploaded model.
    model_description: An additional description of the uploaded model.
    expand: Whether to expand the model's weights before uploading it.
    progress: Whether to display a tqdm progress bar.
Methods
fit(self, nb_epochs: int, eps: Optional[float] = None, batch_size: Optional[int] = None, max_grad_norm: Optional[float] = None, lr: Optional[float] = None, metric_eps: Optional[float] = None, timeout: float = 60.0, poll_delay: float = 0.2, per_n_epochs_checkpoint: int = 0, per_n_steps_checkpoint: int = 0, resume: bool = False) -> None
-
Fits the uploaded model to the training dataset with the given hyperparameters.
Args:
    nb_epochs: The number of epochs to train the model for.
    eps: The global privacy budget for the DP-SGD algorithm.
    max_grad_norm: Overrides the gradient clipping threshold passed to the constructor.
    lr: Overrides the learning rate of the optimizer configuration passed to the constructor.
    metric_eps: A global privacy budget for loss disclosure over the whole training, overriding the default per-batch budget.
    timeout: Timeout in seconds between two updates of the loss on the server side. When it elapses without an update, polling ends and the progress bar is closed.
    poll_delay: Delay in seconds between two polling requests for the loss.
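The max_grad_norm argument is the per-example clipping threshold used by DP-SGD. The plain-Python sketch below illustrates what that threshold does in a generic DP-SGD step; it is not BastionLab's server-side implementation. Each example's gradient is rescaled so its L2 norm is at most max_grad_norm, the clipped gradients are summed, Gaussian noise with standard deviation noise_multiplier × max_grad_norm is added, and the result is divided by the batch size. The helper names and the noise_multiplier parameter are hypothetical: the API above only exposes eps and max_grad_norm.

```python
import math
import random
from typing import List, Optional, Sequence

Vector = List[float]


def l2_norm(v: Sequence[float]) -> float:
    """Euclidean (L2) norm of a gradient vector."""
    return math.sqrt(sum(x * x for x in v))


def clip_gradient(grad: Sequence[float], max_grad_norm: float) -> Vector:
    """Rescale `grad` so that its L2 norm is at most `max_grad_norm`."""
    norm = l2_norm(grad)
    # Gradients already inside the norm ball are left unchanged (scale == 1.0).
    scale = min(1.0, max_grad_norm / norm) if norm > 0 else 1.0
    return [x * scale for x in grad]


def dp_sgd_gradient(
    per_example_grads: Sequence[Sequence[float]],
    max_grad_norm: float,
    noise_multiplier: float,
    rng: Optional[random.Random] = None,
) -> Vector:
    """Hypothetical generic DP-SGD aggregation: clip, sum, add noise, average."""
    rng = rng or random.Random()
    clipped = [clip_gradient(g, max_grad_norm) for g in per_example_grads]
    dim = len(clipped[0])
    summed = [sum(g[i] for g in clipped) for i in range(dim)]
    # Noise is calibrated to the clipping threshold, which bounds how much
    # any single example can contribute to the sum.
    std = noise_multiplier * max_grad_norm
    noisy = [s + rng.gauss(0.0, std) for s in summed]
    return [x / len(per_example_grads) for x in noisy]
```

For example, with max_grad_norm=1.0 the gradient [3.0, 4.0] (norm 5) is scaled down to [0.6, 0.8], while [0.3, 0.4] (norm 0.5) passes through unchanged. A smaller threshold bounds each example's influence more tightly, at the cost of distorting large gradients.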
get_model(self) -> torch.nn.modules.module.Module
-
Returns the model passed to the constructor, with its weights updated to those obtained by training on the server.
test(self, test_dataset: Union[torch.utils.data.dataset.Dataset, bastionlab_pb2.Reference, None] = None, batch_size: Optional[int] = None, metric: Optional[str] = None, metric_eps: Optional[float] = None, timeout: int = 100, poll_delay: float = 0.2) -> None
-
Tests the remote model with the test dataset provided to the RemoteLearner.
Args:
    test_dataset: Overrides the test dataset passed to the RemoteDataset constructor.
    metric: The name of the test metric. If not provided, the training loss is used. The available metrics are the supported loss functions and "accuracy".
    metric_eps: A global privacy budget for metric disclosure over the whole testing procedure, overriding the default per-batch budget.
    timeout: Timeout in seconds between two updates of the metric on the server side. When it elapses without an update, polling ends and the progress bar is closed.
    poll_delay: Delay in seconds between two polling requests for the metric.
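Both fit and test report progress by polling the server, and timeout measures the time elapsed since the last update rather than the total duration of the run. The sketch below shows that loop, assuming a hypothetical fetch callable that returns the newest value or None when nothing new has arrived; it is not BastionLab's client code. The clock and sleep parameters are also illustrative and are injectable so the loop can be exercised without real waiting.

```python
import time
from typing import Callable, List, Optional


def poll_until_idle(
    fetch: Callable[[], Optional[float]],
    timeout: float = 60.0,
    poll_delay: float = 0.2,
    clock: Callable[[], float] = time.monotonic,
    sleep: Callable[[float], None] = time.sleep,
) -> List[float]:
    """Collect values from `fetch` until none arrives for `timeout` seconds.

    `fetch` is a hypothetical callable returning the latest metric value,
    or None when the server has not produced a new one.
    """
    values: List[float] = []
    last_update = clock()
    while clock() - last_update < timeout:
        value = fetch()
        if value is not None:
            values.append(value)
            last_update = clock()  # each update restarts the timeout window
        sleep(poll_delay)  # wait poll_delay seconds between two requests
    return values
```

Because the window restarts on every update, a long run that keeps reporting values is never cut off, whereas a run that stalls stops being polled once timeout seconds pass without a new value.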