bead.src.trainers package
Submodules
bead.src.trainers.inference module
Inference functionality for trained models.
This module provides functionality to perform inference using trained models on test data. It handles the loading of both background and signal data, preprocessing, and passing it through the model to get reconstructions and latent representations. The resulting metrics, reconstructions, and latent variables are saved for later analysis.
- Functions:
seed_worker: Sets seeds for DataLoader workers to ensure reproducibility.
infer: Main function for performing inference on test data.
- bead.src.trainers.inference.infer(data_bkg, data_sig, model_path, output_path, config, verbose: bool = False)[source]
Performs inference on the test data with a trained model. Both the background and signal data are converted to the correct type via torch.Tensor() and batched based on config.batch_size, with torch.utils.data.DataLoader doing the splitting. Each batch is passed through the model to obtain reconstructions and latent representations, and the resulting metrics, reconstructions, and latent variables are saved to the output directory for later analysis. For reproducibility, the seeds can also be fixed in this function. A call sketch follows the parameter list below.
- Parameters:
data_bkg (Tuple) – Tuple containing the background data
data_sig (Tuple) – Tuple containing the signal data
model_path (string) – Path to the model directory
output_path (string) – Path to the output directory
config (dataClass) – Base class selecting user inputs
verbose (bool) – Verbose mode, default is False
- Returns:
True if inference was successful, False otherwise
- Return type:
bool
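A hypothetical call sketch; the data tuples, paths, and config object are placeholders the caller must supply, not values shipped with BEAD:

```python
# Hypothetical usage sketch: bkg_tuple, sig_tuple, the paths, and config
# are illustrative placeholders, not BEAD's documented defaults.
from bead.src.trainers.inference import infer

success = infer(
    data_bkg=bkg_tuple,                   # background data tuple (placeholder)
    data_sig=sig_tuple,                   # signal data tuple (placeholder)
    model_path="workspaces/demo/models",  # assumed directory layout
    output_path="workspaces/demo/output",
    config=config,                        # user-populated options object
    verbose=True,
)
if not success:
    raise RuntimeError("inference failed")
```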
bead.src.trainers.training module
Training functionality for anomaly detection models. This module provides functionality for training neural network models for anomaly detection. It includes functions for model fitting, validation, and the main training loop that handles data loading, model initialization, optimization, and early stopping.
- Functions:
fit: Performs one epoch of training on the training set.
validate: Evaluates the model on the validation set.
seed_worker: Sets seeds for workers to ensure reproducibility.
train: Main function that handles the entire training process.
- bead.src.trainers.training.fit(config, ddp_model, dataloader, loss_fn, optimizer, device, scaler, is_ddp_active, local_rank, epoch_num, verbose: bool = False)[source]
This function trains the model on the training set for one epoch. It computes the losses, performs backpropagation, and steps the optimizer. A minimal per-batch sketch follows the return type below.
- Parameters:
config (dataClass) – Base class selecting user inputs
ddp_model (modelObject) – The model to train, with explicit handling for DDP
dataloader (torch.DataLoader) – Defines the batched data which the model is trained on
loss_fn (lossObject) – Defines the loss function used to train the model
optimizer (torch.optim) – Optimizer used for gradient descent
device (torch.device) – Chooses which device to use with torch
scaler (torch.cuda.amp.GradScaler) – Scaler for mixed precision training
is_ddp_active (bool) – Flag indicating if DDP is active
local_rank (int) – Local rank of the process in DDP
epoch_num (int) – Current epoch number
verbose (bool) – Verbose mode, default is False
- Returns:
Training losses, epoch loss, and the trained model
- Return type:
list, model object
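For orientation, a minimal sketch of the per-batch work a fit-style loop performs, assuming a plain reconstruction objective with mixed precision; the names and loss signature are illustrative, not BEAD's exact internals:

```python
# Minimal sketch of one training epoch with AMP; assumes a reconstruction
# loss of the form loss_fn(output, target). Not BEAD's verbatim code.
import torch

def fit_epoch(model, dataloader, loss_fn, optimizer, device, scaler):
    model.train()
    epoch_loss = 0.0
    for (batch,) in dataloader:
        batch = batch.to(device)
        optimizer.zero_grad()
        # Forward pass under autocast when mixed precision is enabled.
        with torch.autocast(device_type=device.type, enabled=scaler.is_enabled()):
            recon = model(batch)
            loss = loss_fn(recon, batch)
        # Scale the loss, backpropagate, and step the optimizer.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        epoch_loss += loss.item()
    return epoch_loss / len(dataloader)
```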
- bead.src.trainers.training.seed_worker(worker_id)[source]
PyTorch helper that fixes the seeds of DataLoader worker processes for reproducibility; see the sketch below.
- Parameters:
worker_id (int) – ID of the DataLoader worker process
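This helper follows the standard PyTorch worker-seeding recipe; the body shown is the pattern from the PyTorch reproducibility notes, given here as a sketch rather than BEAD's verbatim code:

```python
# Standard PyTorch worker-seeding pattern: derive a per-worker seed from
# the base seed PyTorch assigns, then reseed numpy and random with it.
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset

def seed_worker(worker_id):
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

dataset = TensorDataset(torch.randn(128, 4))
loader = DataLoader(
    dataset,
    batch_size=32,
    num_workers=2,
    worker_init_fn=seed_worker,  # reseeds numpy/random inside each worker
    generator=g,                 # fixes the shuffling order across runs
)
```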
- bead.src.trainers.training.train(data, labels, output_path, config, verbose: bool = False)[source]
Runs the entire training loop by calling fit() and validate(). Apart from this, it is the main function where the data is converted to the correct type via torch.Tensor(), and where batching based on config.batch_size is done, with torch.utils.data.DataLoader performing the splitting. Torch AMP and DDP are also set up here if the user has selected them in the config file. Applying either EarlyStopping or the LR scheduler is likewise handled here, based on their respective config arguments. For reproducibility, the seeds can be fixed in this function via the deterministic_algorithm config flag. A call sketch follows the return type below.
- Parameters:
data (Tuple) – Tuple containing the training and validation data
labels (Tuple) – Tuple containing the training and validation labels
output_path (string) – Path to the output directory
config (dataClass) – Base class selecting user inputs
verbose (bool) – If True, prints additional information during training
- Returns:
fully trained model ready to perform inference
- Return type:
modelObject
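A hypothetical call sketch; the data/label tuples, path, and config object are placeholders supplied by the caller:

```python
# Hypothetical usage sketch: the tuples, path, and config are
# illustrative placeholders, not BEAD's documented defaults.
from bead.src.trainers.training import train

trained_model = train(
    data=(train_data, val_data),        # training/validation split (placeholder)
    labels=(train_labels, val_labels),  # matching labels (placeholder)
    output_path="workspaces/demo/output",
    config=config,                      # user-populated options object
    verbose=True,
)
```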
- bead.src.trainers.training.validate(config, ddp_model, dataloader, loss_fn, device, is_ddp_active, local_rank, epoch_num, verbose: bool = False)[source]
Function used to validate the model during training. Not necessary for doing compression, but gives a good indication of whether the selected model is a good fit or not. A minimal sketch of this pass follows the return type below.
- Parameters:
config (dataClass) – Base class selecting user inputs
ddp_model (modelObject) – The model to validate, passed directly from fit()
dataloader (torch.DataLoader) – Defines the batched data which the model is validated on
loss_fn (lossObject) – Defines the loss function used to train the model
device (torch.device) – Chooses which device to use with torch
is_ddp_active (bool) – Flag indicating if DDP is active
local_rank (int) – Local rank of the process in DDP
epoch_num (int) – Current epoch number
verbose (bool) – Verbose mode, default is False
- Returns:
Validation loss
- Return type:
float
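For orientation, a minimal sketch of what a validate-style pass does: forward passes with gradients disabled and an averaged loss. The names and loss signature are illustrative, not BEAD's exact internals:

```python
# Minimal validation-pass sketch; assumes a reconstruction loss of the
# form loss_fn(output, target). Not BEAD's verbatim code.
import torch

def validate_epoch(model, dataloader, loss_fn, device):
    model.eval()
    total_loss = 0.0
    with torch.no_grad():  # no gradients needed during validation
        for (batch,) in dataloader:
            batch = batch.to(device)
            recon = model(batch)
            total_loss += loss_fn(recon, batch).item()
    return total_loss / len(dataloader)
```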