
Module bastionlab.torch.utils

Classes

MultipleOutputWrapper(module: torch.nn.modules.module.Module, output: int = 0)

Utility wrapper to select one output of a model with multiple outputs.

Args:

  • module: A model with more than one output.
  • output: Index of the output to retain.

Initializes internal Module state, shared by both nn.Module and ScriptModule.
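A minimal sketch of how a wrapper like this could work, assuming the wrapped model returns a tuple or list of tensors. The class name `SelectOutput` and the toy `TwoHeads` model are hypothetical illustrations, not BastionLab's actual source.

```python
import torch
from torch import nn


class SelectOutput(nn.Module):
    """Hypothetical stand-in for MultipleOutputWrapper: keep one output of a multi-output model."""

    def __init__(self, module: nn.Module, output: int = 0) -> None:
        super().__init__()
        self.module = module  # registered as a submodule, so its parameters are tracked
        self.output = output  # index of the output to retain

    def forward(self, *args, **kwargs) -> torch.Tensor:
        # Pass all arguments through unchanged, then keep only the selected output.
        return self.module(*args, **kwargs)[self.output]


class TwoHeads(nn.Module):
    """Toy model returning two tensors (hypothetical, for illustration only)."""

    def __init__(self) -> None:
        super().__init__()
        self.a = nn.Linear(4, 2)
        self.b = nn.Linear(4, 3)

    def forward(self, x: torch.Tensor):
        return self.a(x), self.b(x)


wrapped = SelectOutput(TwoHeads(), output=1)
y = wrapped(torch.randn(5, 4))
print(y.shape)  # torch.Size([5, 3])
```

Because the inner model is stored as a submodule attribute, `wrapped.parameters()` still yields all of its weights, so an optimizer built on the wrapper trains the whole model.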

Ancestors (in MRO)

  • torch.nn.modules.module.Module

Class variables

dump_patches: bool

training: bool

Methods

forward(self, *args, **kwargs) -> torch.Tensor

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note: Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
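The difference described in the note can be observed with a forward hook. This sketch uses a plain `nn.Linear` layer rather than `MultipleOutputWrapper`; the `calls` list is only there to record when the hook fires.

```python
import torch
from torch import nn

calls = []
layer = nn.Linear(2, 2)
# Record every time the forward hook fires.
layer.register_forward_hook(lambda module, inputs, output: calls.append("hook"))

x = torch.zeros(1, 2)
layer(x)          # __call__ runs forward() and then the registered hook
layer.forward(x)  # calling forward() directly bypasses the hook

print(calls)  # ['hook']
```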

TensorDataset(columns: List[torch.Tensor], labels: Optional[torch.Tensor])

A simple dataset, compliant with Torch's Dataset interface, built upon tensors representing columns and labels.

Args:

  • columns: Tensors that represent the columns of the dataset (a column contains the values of a given input for all samples).
  • labels: A tensor containing the labels of all samples, or None.

Ancestors (in MRO)

  • torch.utils.data.dataset.Dataset
  • typing.Generic

Methods

__getitem__(self, idx: int) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]

__len__(self) -> int
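A minimal sketch of a column-oriented dataset with the same `__getitem__` and `__len__` signatures as documented above. The class name `ColumnDataset`, the length check, and the sample tensors are assumptions for illustration, not BastionLab's implementation.

```python
from typing import List, Optional, Tuple

import torch
from torch.utils.data import DataLoader, Dataset


class ColumnDataset(Dataset):
    """Hypothetical sketch mirroring the TensorDataset interface."""

    def __init__(self, columns: List[torch.Tensor], labels: Optional[torch.Tensor]) -> None:
        # Assumed check: every column (and the labels, if any) has one entry per sample along dim 0.
        n = len(columns[0])
        if any(len(c) != n for c in columns) or (labels is not None and len(labels) != n):
            raise ValueError("all columns and labels must have the same number of samples")
        self.columns = columns
        self.labels = labels

    def __getitem__(self, idx: int) -> Tuple[List[torch.Tensor], Optional[torch.Tensor]]:
        # One sample = the idx-th entry of every column, plus its label (or None).
        inputs = [c[idx] for c in self.columns]
        label = self.labels[idx] if self.labels is not None else None
        return inputs, label

    def __len__(self) -> int:
        return len(self.columns[0])


# Two hypothetical input columns for 4 samples, plus one label per sample.
ids = torch.arange(12).reshape(4, 3)
mask = torch.ones(4, 3)
labels = torch.tensor([0, 1, 1, 0])
ds = ColumnDataset([ids, mask], labels)

inputs, label = ds[2]
print(inputs[0], label)  # tensor([6, 7, 8]) tensor(1)

# The default collate function stacks each column separately across the batch.
cols, ys = next(iter(DataLoader(ds, batch_size=2)))
print(cols[0].shape, ys)  # torch.Size([2, 3]) tensor([0, 1])
```

Keeping inputs as a list of columns lets a model with several inputs receive each one as a separate batched tensor. Note that the default collate function cannot batch `None` labels, so an unlabeled dataset would need a custom `collate_fn` when used with a `DataLoader`.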