bead.src.trainers package

Submodules

bead.src.trainers.inference module

Inference functionality for trained models.

This module provides functionality to perform inference using trained models on test data. It handles the loading of both background and signal data, preprocessing, and passing it through the model to get reconstructions and latent representations. The resulting metrics, reconstructions, and latent variables are saved for later analysis.

Functions:

  • seed_worker: Sets seeds for workers to ensure reproducibility.

  • infer: Main function for performing inference on test data.

bead.src.trainers.inference.infer(data_bkg, data_sig, model_path, output_path, config, verbose: bool = False)[source]

Runs the full inference loop on the test data. This is the main function where the background and signal data are converted to the correct type via torch.Tensor() and batched based on config.batch_size, with torch.utils.data.DataLoader doing the splitting. The trained model is loaded from model_path, and the data is passed through it to obtain reconstructions and latent representations; the resulting metrics, reconstructions, and latent variables are saved to output_path for later analysis. For reproducibility, the seeds can also be fixed in this function. A usage sketch follows this entry.

Parameters:
  • data_bkg (Tuple) – Tuple containing the background data

  • data_sig (Tuple) – Tuple containing the signal data

  • model_path (string) – Path to the model directory

  • output_path (string) – Path to the output directory

  • config (dataClass) – Base class selecting user inputs

  • verbose (bool) – Verbose mode, default is False

Returns:

True if inference was successful, False otherwise

Return type:

bool
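For illustration, a hypothetical call might look like the sketch below; the data tuples, paths, and config object are placeholders that depend on the project's preprocessing pipeline, not values defined by this module:

    from bead.src.trainers.inference import infer

    # data_bkg, data_sig and config are assumed to come from the
    # project's preprocessing and configuration steps (placeholders).
    success = infer(
        data_bkg=data_bkg,
        data_sig=data_sig,
        model_path="models/my_model",     # hypothetical path
        output_path="results/inference",  # hypothetical path
        config=config,
        verbose=True,
    )
    if not success:
        print("Inference failed")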

bead.src.trainers.inference.seed_worker(worker_id)[source]

PyTorch implementation to fix the seeds.

Parameters:

worker_id (int) – ID of the DataLoader worker process
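For reference, the canonical PyTorch recipe for seeding DataLoader workers, which a function like this typically follows (a sketch of the standard pattern, not necessarily the exact implementation):

    import random

    import numpy as np
    import torch

    def seed_worker(worker_id):
        # Derive a per-worker seed from the base seed PyTorch assigns
        # to this worker, then seed NumPy and Python's random module.
        worker_seed = torch.initial_seed() % 2**32
        np.random.seed(worker_seed)
        random.seed(worker_seed)

The function is meant to be passed to torch.utils.data.DataLoader via its worker_init_fn argument.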

bead.src.trainers.training module

Training functionality for anomaly detection models.

This module provides functionality for training neural network models for anomaly detection. It includes functions for model fitting, validation, and the main training loop that handles data loading, model initialization, optimization, and early stopping.

Functions:

  • fit: Performs one epoch of training on the training set.

  • validate: Evaluates the model on the validation set.

  • seed_worker: Sets seeds for workers to ensure reproducibility.

  • train: Main function that handles the entire training process.

bead.src.trainers.training.fit(config, ddp_model, dataloader, loss_fn, optimizer, device, scaler, is_ddp_active, local_rank, epoch_num, verbose: bool = False)[source]

This function trains the model on the train set for one epoch. It computes the losses, performs backpropagation, and updates the optimizer. A minimal sketch of this loop follows the entry.

Parameters:
  • config (dataClass) – Base class selecting user inputs

  • ddp_model (modelObject) – The model you wish to train - explicit handling for DDP

  • dataloader (torch.DataLoader) – Defines the batched data which the model is trained on

  • loss_fn (lossObject) – Defines the loss function used to train the model

  • optimizer (torch.optim) – Chooses optimizer for gradient descent.

  • device (torch.device) – Chooses which device to use with torch

  • scaler (torch.cuda.amp.GradScaler) – Scaler for mixed precision training

  • is_ddp_active (bool) – Flag indicating if DDP is active

  • local_rank (int) – Local rank of the process in DDP

  • epoch_num (int) – Current epoch number

  • verbose (bool) – Verbose mode, default is False

Returns:

Training losses, epoch loss, and the trained model

Return type:

list, modelObject
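To make the mixed-precision flow concrete, here is a minimal sketch of a fit-style epoch; the names and the reconstruction-style loss are illustrative assumptions, not the module's actual code:

    import torch

    def fit_sketch(model, dataloader, loss_fn, optimizer, device, scaler):
        # Illustrative epoch: forward pass under autocast, scaled
        # backward pass, optimizer step through the GradScaler.
        model.train()
        epoch_loss = 0.0
        for (batch,) in dataloader:  # assumes a single-tensor TensorDataset
            batch = batch.to(device)
            optimizer.zero_grad()
            with torch.autocast(device_type=device.type):
                output = model(batch)
                loss = loss_fn(output, batch)  # reconstruction loss (assumption)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            epoch_loss += loss.item()
        return epoch_loss / len(dataloader)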

bead.src.trainers.training.seed_worker(worker_id)[source]

PyTorch implementation to fix the seeds.

Parameters:

worker_id (int) – ID of the DataLoader worker process

bead.src.trainers.training.train(data, labels, output_path, config, verbose: bool = False)[source]

Processes the entire training loop by calling fit() and validate(). Apart from this, it is the main function where the data is converted to the correct type via torch.Tensor() and batched based on config.batch_size, with torch.utils.data.DataLoader doing the splitting. Torch AMP and DDP are also set up here if the user has selected them in the config file, as are EarlyStopping and the LR scheduler, based on their respective config arguments. For reproducibility, the seeds can be fixed in this function using the deterministic_algorithm config flag. A sketch of the batching and seeding mechanics follows this entry.

Parameters:
  • data (Tuple) – Tuple containing the training and validation data

  • labels (Tuple) – Tuple containing the training and validation labels

  • output_path (string) – Path to the output directory

  • config (dataClass) – Base class selecting user inputs

  • verbose (bool) – If True, prints additional information during training

Returns:

Fully trained model ready to perform inference

Return type:

modelObject
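A minimal sketch of the batching and seeding described above; config, train_data, and train_labels are placeholders, while seed_worker is this module's helper:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    from bead.src.trainers.training import seed_worker

    # Convert the data to tensors and batch by config.batch_size,
    # fixing the generator and worker seeds for reproducibility.
    g = torch.Generator()
    g.manual_seed(0)

    dataset = TensorDataset(torch.Tensor(train_data), torch.Tensor(train_labels))
    loader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        worker_init_fn=seed_worker,
        generator=g,
    )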

bead.src.trainers.training.validate(config, ddp_model, dataloader, loss_fn, device, is_ddp_active, local_rank, epoch_num, verbose: bool = False)[source]

Function used to validate the training. Not necessary for doing compression, but gives a good indication of whether the selected model is a good fit or not. An illustrative sketch follows this entry.

Parameters:
  • config (dataClass) – Base class selecting user inputs

  • ddp_model (modelObject) – Defines the model one wants to validate; it is passed directly from fit().

  • dataloader (torch.DataLoader) – Defines the batched data which the model is validated on

  • loss_fn (lossObject) – Defines the loss function used to train the model

  • device (torch.device) – Chooses which device to use with torch

  • is_ddp_active (bool) – Flag indicating if DDP is active

  • local_rank (int) – Local rank of the process in DDP

  • epoch_num (int) – Current epoch number

  • verbose (bool) – Verbose mode, default is False

Returns:

Validation loss

Return type:

float
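For illustration, a validate-style pass typically looks like the sketch below; the names are placeholders, and the documented implementation additionally handles DDP loss aggregation:

    import torch

    def validate_sketch(model, dataloader, loss_fn, device):
        # Evaluation mode with gradient tracking disabled; return the
        # mean validation loss over all batches.
        model.eval()
        total_loss = 0.0
        with torch.no_grad():
            for (batch,) in dataloader:  # assumes a single-tensor TensorDataset
                batch = batch.to(device)
                output = model(batch)
                total_loss += loss_fn(output, batch).item()
        return total_loss / len(dataloader)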

Module contents