
Module bastionlab.torch

Sub-modules

Classes

BastionLabTorch()

BastionLab Torch RPC Handle

Methods

RemoteDataset(self, *args, **kwargs) ‑> bastionlab.torch.RemoteDataset

Returns a RemoteDataset object encapsulating training and testing dataloaders on the remote server. The object uses this client to communicate with the server.

Args:
*args: All positional arguments are forwarded to the bastionlab.torch.RemoteDataset constructor.
**kwargs: All keyword arguments are forwarded to the bastionlab.torch.RemoteDataset constructor.

RemoteLearner(self, *args, **kwargs) ‑> bastionlab.torch.RemoteLearner

Returns a bastionlab.torch.RemoteLearner object encapsulating a model and hyperparameters for training and testing on the remote server. The learner uses this client to communicate with the server.

Args:
*args: All positional arguments are forwarded to the bastionlab.torch.RemoteLearner constructor.
**kwargs: All keyword arguments are forwarded to the bastionlab.torch.RemoteLearner constructor.
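Both factory methods only forward their arguments. The sketch below shows that forwarding pattern for RemoteLearner. It assumes the handle passes itself as the constructor's first argument, which the client argument in the RemoteLearner documentation suggests. SketchClient and SketchRemoteLearner are hypothetical stand-ins, not BastionLab classes.

```python
from typing import Any


class SketchRemoteLearner:
    """Hypothetical stand-in for bastionlab.torch.RemoteLearner (not the real class)."""

    def __init__(self, client: Any, model: Any, *, loss: str = "cross_entropy") -> None:
        self.client = client
        self.model = model
        self.loss = loss


class SketchClient:
    """Hypothetical stand-in for BastionLabTorch (not the real class)."""

    def RemoteLearner(self, *args: Any, **kwargs: Any) -> SketchRemoteLearner:
        # Pass this client as the first constructor argument (an assumption),
        # then forward the caller's positional and keyword arguments unchanged.
        return SketchRemoteLearner(self, *args, **kwargs)


client = SketchClient()
learner = client.RemoteLearner("my_model", loss="l2")
print(learner.client is client, learner.model, learner.loss)  # True my_model l2
```

With this pattern, callers never repeat the client when creating learners, and any constructor argument stays usable through the factory.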

RemoteTensor(self, tensor: torch.Tensor) ‑> bastionlab.torch.RemoteTensor

Returns a RemoteTensor which represents a reference to the uploaded tensor.

Args: tensor: The tensor to be uploaded.

delete_dataset(self, ref: Union[ForwardRef('bastionlab.torch.RemoteDataset'), bastionlab_pb2.Reference])

Deletes the dataset corresponding to the given reference ref on the BastionLab Torch server.

Args: ref: The RemoteDataset, or the BastionLab Torch gRPC protocol reference, of the dataset to delete.

delete_module(self, ref: bastionlab_pb2.Reference) ‑> None

Deletes the module corresponding to the given reference ref on the BastionLab Torch server.

Args: ref: BastionLab Torch gRPC protocol reference of the module to be deleted.

fetch_dataset(self, ref: Union[ForwardRef('bastionlab.torch.RemoteDataset'), bastionlab_pb2.Reference]) ‑> [bastionlab.torch.utils](utils.md).TensorDataset

Fetches the remote dataset designated by a BastionLab Torch gRPC protocol reference.

Args: ref: The RemoteDataset, or the BastionLab Torch gRPC protocol reference object, corresponding to the remote dataset.

Returns: A dataset instance built from received data.

fetch_model_weights(self, model: torch.nn.modules.module.Module, ref: bastionlab_pb2.Reference) ‑> None

Fetches the weights of a remote trained model with a BastionLab Torch gRPC protocol reference and loads them into the given model instance.

Args:
model: The PyTorch nn.Module whose weights will be replaced by the fetched weights.
ref: BastionLab Torch gRPC protocol reference object corresponding to the remote trained model.

get_available_datasets(self) ‑> List[bastionlab_pb2.Reference]
Returns the list of BastionLab Torch gRPC protocol references of all datasets on the server.
get_available_devices(self) ‑> List[str]
Returns the list of devices available on the server.
get_available_models(self) ‑> List[bastionlab_pb2.Reference]
Returns the list of BastionLab Torch gRPC protocol references of all available models on the server.
get_available_optimizers(self) ‑> List[str]
Returns the list of optimizers supported by the server.
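RemoteLearner's device argument must name one of the devices returned by get_available_devices, so a client may want to check its preferred device before training. Below is a minimal pure-Python sketch of that check. pick_device is a hypothetical helper, and the lists passed to it stand in for the real return value. The device names used here are assumptions.

```python
from typing import List, Sequence


def pick_device(available: List[str], preferred: Sequence[str] = ("cuda", "cpu")) -> str:
    """Return the first preferred device that the server reports as available.

    `available` stands in for the list returned by get_available_devices();
    the device names are illustrative assumptions, not a documented set.
    """
    for name in preferred:
        if name in available:
            return name
    raise ValueError(f"none of {list(preferred)} is available; server offers {available}")


print(pick_device(["cpu"]))          # cpu
print(pick_device(["cpu", "cuda"]))  # cuda
```

Against a real server, the argument would presumably be the result of calling get_available_devices on a BastionLabTorch instance. The same check applies to optimizer names and get_available_optimizers.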
get_metric(self, run: bastionlab_pb2.Reference) ‑> bastionlab_torch_pb2.Metric

Returns the value of the metric associated with the given run reference.

Args: run: BastionLab Torch gRPC protocol reference of the run whose metric is read.

send_dataset(self, dataset: torch.utils.data.dataset.Dataset, name: str, description: str = '', privacy_limit: Optional[float] = None, chunk_size: int = 4194285, batch_size: int = 1024, train_dataset: Optional[bastionlab_pb2.Reference] = None, progress: bool = False) ‑> bastionlab_pb2.Reference

Uploads a PyTorch Dataset to the BastionLab Torch server.

Args:
dataset: The PyTorch Dataset to upload.
name: A name for the dataset being uploaded.
description: A string description of the dataset being uploaded.
chunk_size: Size of a chunk in the BastionLab Torch gRPC protocol, in bytes.
batch_size: Size of a unit of serialization, in number of samples. Increasing this value may increase serialization throughput at the price of higher memory consumption.
train_dataset: Metadata. True means this dataset is suited for training; False means it should be used for testing/validation only.

Returns: BastionLab Torch gRPC protocol's reference object.
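The two size parameters work at different levels. batch_size counts samples per serialization unit, while chunk_size caps the size in bytes of each chunk sent over gRPC. The default of 4,194,285 bytes is just under 4 MiB (4,194,304 bytes).

The sketch below shows both groupings in plain Python. The helper names are hypothetical, and it illustrates only the arithmetic. It is not BastionLab's serialization code.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")

DEFAULT_CHUNK_SIZE = 4194285  # default chunk_size of send_dataset / send_model, in bytes


def iter_batches(samples: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Group samples into serialization units of at most batch_size samples."""
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])


def iter_chunks(payload: bytes, chunk_size: int = DEFAULT_CHUNK_SIZE) -> Iterator[bytes]:
    """Split a serialized payload into pieces of at most chunk_size bytes."""
    for start in range(0, len(payload), chunk_size):
        yield payload[start:start + chunk_size]


batches = list(iter_batches(list(range(2500)), batch_size=1024))
print([len(b) for b in batches])  # [1024, 1024, 452]

payload = bytes(10_000_000)  # 10 MB of zeros standing in for serialized tensors
chunks = list(iter_chunks(payload))
print([len(c) for c in chunks])  # [4194285, 4194285, 1611430]
```

A larger batch_size means fewer, bigger serialization units, which matches the throughput/memory trade-off described above. chunk_size only affects how the resulting bytes are split for transport.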

send_model(self, model: torch.nn.modules.module.Module, name: str, description: str = '', chunk_size: int = 4194285, progress: bool = False) ‑> bastionlab_pb2.Reference

Uploads a PyTorch module to the BastionLab Torch server.

This endpoint converts PyTorch modules into TorchScript modules and sends them to the BastionLab Torch server over gRPC.

Args:
model: The PyTorch nn.Module to upload.
name: A name for the module being uploaded.
description: A string description of the module being uploaded.
chunk_size: Size of a chunk in the BastionLab Torch gRPC protocol, in bytes.

Returns: BastionLab Torch gRPC protocol's reference object.

RemoteDataset()

RemoteDataset(inputs: List[bastionlab.torch.data.RemoteTensor], labels: bastionlab.torch.data.RemoteTensor, name: Optional[str] = 'RemoteDataset', description: Optional[str] = 'RemoteDataset', privacy_limit: Optional[float] = -1.0, identifier: Optional[str] = '')

Class variables

description: Optional[str] :

identifier: Optional[str] :

inputs: List[[bastionlab.torch.data](data.md).RemoteTensor] :

labels: [bastionlab.torch.data](data.md).RemoteTensor :

name: Optional[str] :

privacy_limit: Optional[float] :

Instance variables

input_dtype: torch.dtype
The dtype of the stored input tensors
nb_samples: int
The number of samples in the RemoteDataset
RemoteLearner()

Represents a remote model on the server along with hyperparameters to train and test it.

The RemoteLearner takes the model to be trained together with the remote dataset used to train and test it.

Args:
client: A BastionAI client to be used to access server resources.
model: A PyTorch nn.Module or a BastionAI gRPC protocol reference to a remote model.
remote_dataset: A BastionAI remote dataloader.
loss: The name of the loss to use for training the model. Supported loss functions are "l2" and "cross_entropy".
optimizer: The configuration of the optimizer to use during training; refer to the documentation of OptimizerConfig.
device: Name of the device on which to train the model. The list of supported devices may be obtained with the get_available_devices endpoint of the BastionLabTorch object.
max_grad_norm: The clipping threshold for gradients in DP-SGD.
metric_eps_per_batch: The privacy budget allocated to the disclosure of the loss of every batch. May be overridden by providing a global budget for loss disclosure over the whole training when calling the fit method.
model_name: A name for the uploaded model.
model_description: An additional description for the uploaded model.
expand: Whether or not to expand the model's weights before uploading it.
progress: Whether or not to display a tqdm progress bar.
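max_grad_norm is the clipping bound C used by DP-SGD. DP-SGD rescales each per-sample gradient g to g * min(1, C / ||g||_2) before noise is added. The sketch below shows only that clipping step, in pure Python on a flat list of floats. The real training runs server-side on tensors, and clip_gradient is a hypothetical helper.

```python
import math
from typing import List


def clip_gradient(grad: List[float], max_grad_norm: float) -> List[float]:
    """Rescale one per-sample gradient so its L2 norm is at most max_grad_norm.

    This is the DP-SGD clipping step: g * min(1, C / ||g||_2).
    Noise addition and averaging over the batch are omitted.
    """
    norm = math.sqrt(sum(x * x for x in grad))
    if norm <= max_grad_norm:
        return list(grad)  # already within the threshold: left unchanged
    scale = max_grad_norm / norm
    return [x * scale for x in grad]


print(clip_gradient([6.0, 8.0], max_grad_norm=5.0))  # [3.0, 4.0]
print(clip_gradient([0.3, 0.4], max_grad_norm=1.0))  # [0.3, 0.4]
```

Clipping bounds how much any single sample can move the model. That bound is what lets the added noise translate into a privacy guarantee. A smaller max_grad_norm gives tighter per-sample influence but distorts large gradients more.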

Methods

fit(self, nb_epochs: int, eps: Optional[float] = None, batch_size: Optional[int] = None, max_grad_norm: Optional[float] = None, lr: Optional[float] = None, metric_eps: Optional[float] = None, timeout: float = 60.0, poll_delay: float = 0.2, per_n_epochs_checkpoint: int = 0, per_n_steps_checkpoint: int = 0, resume: bool = False) ‑> None

Fits the uploaded model to the training dataset with given hyperparameters.

Args:
nb_epochs: The number of epochs to train the model for.
eps: The global privacy budget for the DP-SGD algorithm.
max_grad_norm: Overrides the default clipping threshold for gradients passed to the constructor.
lr: Overrides the default learning rate of the optimizer config passed to the constructor.
metric_eps: Global privacy budget for loss disclosure over the whole training; overrides the default per-batch budget.
timeout: Timeout in seconds between two updates of the loss on the server side. When it elapses without an update, polling ends and the progress bar is closed.
poll_delay: Delay in seconds between two polling requests for the loss.
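The timeout and poll_delay arguments describe a polling loop. The client asks for the latest loss every poll_delay seconds and stops once timeout seconds pass without a new value. The sketch below illustrates those documented semantics in pure Python. It is not the client's actual loop, and read_metric is a hypothetical callable that returns None when no new value is available. The defaults mirror fit's defaults (60.0 and 0.2).

```python
import time
from typing import Callable, Iterator, List, Optional


def poll_metric(
    read_metric: Callable[[], Optional[float]],
    timeout: float = 60.0,
    poll_delay: float = 0.2,
) -> List[float]:
    """Collect metric values until none arrives for `timeout` seconds."""
    values: List[float] = []
    last_update = time.monotonic()
    while time.monotonic() - last_update < timeout:
        value = read_metric()
        if value is not None:
            values.append(value)
            last_update = time.monotonic()  # a fresh update restarts the timeout window
        time.sleep(poll_delay)  # wait before the next polling request
    return values


# Fake metric source: two updates, one empty poll, one more update, then silence.
_updates: Iterator[Optional[float]] = iter([0.9, 0.7, None, 0.5])
values = poll_metric(lambda: next(_updates, None), timeout=0.5, poll_delay=0.01)
print(values)
```

The timeout counts the time since the last update, not the total run time. A long training run that keeps reporting losses is therefore never cut off, while a stalled one ends after timeout seconds of silence.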

get_model(self) ‑> torch.nn.modules.module.Module
Returns the model passed to the constructor with its weights updated with the weights obtained by training on the server.
test(self, test_dataset: Union[torch.utils.data.dataset.Dataset, bastionlab_pb2.Reference, ForwardRef(None)] = None, batch_size: Optional[int] = None, metric: Optional[str] = None, metric_eps: Optional[float] = None, timeout: int = 100, poll_delay: float = 0.2) ‑> None

Tests the remote model with the test dataloader provided in the RemoteLearner.

Args:
test_dataset: Overrides the test dataset passed to the RemoteDataset constructor.
metric: Name of the test metric. If not provided, the training loss is used. Available metrics are the loss functions and accuracy.
metric_eps: Global privacy budget for metric disclosure over the whole testing procedure; overrides the default per-batch budget.
timeout: Timeout in seconds between two updates of the metric on the server side. When it elapses without an update, polling ends and the progress bar is closed.
poll_delay: Delay in seconds between two polling requests for the metric.
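The metric argument can name a loss function or accuracy. For classification, accuracy is commonly the fraction of samples whose highest-scoring class matches the label. The sketch below implements that definition in pure Python, assuming per-class scores (logits) as input. The server computes its metric on tensors and may differ in details such as privacy noise on the disclosed value. accuracy is a hypothetical helper.

```python
from typing import Sequence


def accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of samples whose arg-max class equals the label."""
    if not labels or len(logits) != len(labels):
        raise ValueError("logits and labels must be non-empty and of equal length")
    correct = sum(
        max(range(len(row)), key=row.__getitem__) == label  # arg-max vs. label
        for row, label in zip(logits, labels)
    )
    return correct / len(labels)


scores = [[0.1, 2.0], [1.5, 0.3], [0.2, 0.9], [3.0, 1.0]]
print(accuracy(scores, [1, 0, 0, 0]))  # 0.75
```

Unlike a loss, accuracy only depends on the predicted class, so it cannot serve as a training objective. That is why it appears as a test metric rather than as a loss option.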

RemoteTensor()

BastionLab reference to a PyTorch (tch) Tensor on the server.

It also stores basic information about the tensor (its dtype and shape).

The dtype of the remote tensor can also be changed through an API call.

Instance variables

dtype: torch.dtype
Returns the torch dtype of the corresponding tensor

identifier: str :

shape
Returns the torch Size of the corresponding tensor

Methods

to(self, dtype: torch.dtype)

Performs a dtype conversion of the tensor.

Args: dtype: The target torch.dtype.