Module bastionlab.torch.learner

Classes

RemoteLearner()

Represents a remote model on the server along with hyperparameters to train and test it.

The remote learner accepts either a local PyTorch module to be uploaded and trained, or a reference to a model already hosted on the server.

Args:
    client: A BastionAI client used to access server resources.
    model: A PyTorch nn.Module or a BastionAI gRPC protocol reference to a model hosted on the server.
    remote_dataset: A BastionAI remote dataloader.
    loss: The name of the loss function to use for training; supported values are "l2" and "cross_entropy".
    optimizer: The configuration of the optimizer to use during training; refer to the documentation of OptimizerConfig.
    device: Name of the device on which to train the model. The list of supported devices may be obtained with the get_available_devices endpoint of the BastionLabTorch object.
    max_grad_norm: The clipping threshold for gradients in DP-SGD.
    metric_eps_per_batch: The privacy budget allocated to the disclosure of the loss of every batch. May be overridden by passing a global budget for loss disclosure over the whole training to the fit method.
    model_name: A name for the uploaded model.
    model_description: An additional description for the uploaded model.
    expand: Whether to expand the model's weights prior to uploading it.
    progress: Whether to display a tqdm progress bar.
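To make the role of max_grad_norm concrete, here is a minimal, plain-Python sketch of the per-sample gradient clipping step used by DP-SGD. The function name and the list-of-floats gradient representation are illustrative assumptions, not part of the BastionLab API; the server performs this step internally on tensors.

```python
import math

def clip_gradient(grad, max_grad_norm):
    """Scale a per-sample gradient so its L2 norm does not exceed max_grad_norm.

    This mirrors the clipping step of DP-SGD: g <- g * min(1, C / ||g||_2),
    where C is the clipping threshold (max_grad_norm).
    """
    norm = math.sqrt(sum(g * g for g in grad))
    scale = min(1.0, max_grad_norm / norm) if norm > 0 else 1.0
    return [g * scale for g in grad]

# A gradient of L2 norm 5.0 is rescaled to norm 1.0 (roughly [0.6, 0.8]);
# a gradient already within the threshold is left untouched.
clip_gradient([3.0, 4.0], 1.0)
clip_gradient([0.1, 0.2], 1.0)
```

Clipping bounds each sample's contribution to the batch gradient, which is what lets the server add calibrated noise and account for the privacy budget eps.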

Methods

fit(self, nb_epochs: int, eps: Optional[float] = None, batch_size: Optional[int] = None, max_grad_norm: Optional[float] = None, lr: Optional[float] = None, metric_eps: Optional[float] = None, timeout: float = 60.0, poll_delay: float = 0.2, per_n_epochs_checkpoint: int = 0, per_n_steps_checkpoint: int = 0, resume: bool = False) -> None

Fits the uploaded model to the training dataset with given hyperparameters.

Args:
    nb_epochs: The number of epochs to train the model for.
    eps: The global privacy budget for the DP-SGD algorithm.
    max_grad_norm: Overrides the default gradient clipping threshold passed to the constructor.
    lr: Overrides the default learning rate of the optimizer config passed to the constructor.
    metric_eps: Global privacy budget for loss disclosure over the whole training; overrides the default per-batch budget.
    timeout: Timeout in seconds between two updates of the loss on the server side. When it elapses without an update, polling ends and the progress bar is terminated.
    poll_delay: Delay in seconds between two polling requests for the loss.
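The relationship between the global metric_eps and the per-batch budget can be sketched as follows. Note that the even split below is an assumption for illustration; the documentation only states that the global budget overrides the per-batch one, not how the server divides it.

```python
def per_batch_metric_eps(metric_eps, nb_epochs, batches_per_epoch):
    """Evenly split a global loss-disclosure budget across every training batch.

    Assumed accounting: each of the nb_epochs * batches_per_epoch loss
    disclosures receives an equal share of the global budget.
    """
    total_batches = nb_epochs * batches_per_epoch
    return metric_eps / total_batches

# A global budget of 10.0 spread over 5 epochs of 100 batches each
# leaves 0.02 per batch.
per_batch_metric_eps(10.0, 5, 100)
```

Under this accounting, training longer (more epochs or more batches) with a fixed global budget means each individual loss value is disclosed with less precision.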

get_model(self) -> torch.nn.modules.module.Module

Returns the model passed to the constructor with its weights updated with the weights obtained by training on the server.
test(self, test_dataset: Optional[Union[torch.utils.data.dataset.Dataset, bastionlab_pb2.Reference]] = None, batch_size: Optional[int] = None, metric: Optional[str] = None, metric_eps: Optional[float] = None, timeout: int = 100, poll_delay: float = 0.2) -> None

Tests the remote model with the test dataset of the RemoteDataset provided to the RemoteLearner constructor, unless overridden.

Args:
    test_dataset: Overrides the test dataset passed to the RemoteDataset constructor.
    metric: Name of the test metric; if not provided, the training loss is used. Available metrics are the loss functions and accuracy.
    metric_eps: Global privacy budget for metric disclosure over the whole testing procedure; overrides the default per-batch budget.
    timeout: Timeout in seconds between two updates of the metric on the server side. When it elapses without an update, polling ends and the progress bar is terminated.
    poll_delay: Delay in seconds between two polling requests for the metric.
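The timeout and poll_delay parameters describe a client-side polling loop: the client repeatedly fetches the latest metric, and stops once no new update arrives within the timeout window. The sketch below illustrates that behavior with a fake metric source; the function names and the (step, value) return shape are assumptions, not the actual BastionLab internals.

```python
import time

def poll_metric(fetch_metric, timeout=100.0, poll_delay=0.2):
    """Poll fetch_metric() until it stops updating for `timeout` seconds.

    fetch_metric is assumed to return the latest (step, value) pair from
    the server; each fresh step resets the timeout window. Returns the
    last value seen.
    """
    last_step, last_value = None, None
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        step, value = fetch_metric()
        if step != last_step:  # fresh update: record it and reset the window
            last_step, last_value = step, value
            deadline = time.monotonic() + timeout
        time.sleep(poll_delay)
    return last_value

# Fake server: emits three metric updates, then goes quiet.
updates = iter([(0, 2.3), (1, 1.7), (2, 1.1)])
latest = [(-1, None)]
def fetch_metric():
    latest[0] = next(updates, latest[0])
    return latest[0]

poll_metric(fetch_metric, timeout=0.5, poll_delay=0.01)  # last value seen: 1.1
```

A short poll_delay makes the progress bar more responsive at the cost of more requests; a short timeout risks ending polling early if the server is slow to produce the next batch's metric.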