Source code for bead.src.trainers.inference

# Copyright 2022 Baler Contributors

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import random
import time
import sys
import numpy as np
from tqdm.rich import tqdm
import warnings
from tqdm import TqdmExperimentalWarning

from torch.nn import functional as F
import torch
from torch.utils.data import DataLoader, ConcatDataset

from ..utils import helper, loss, diagnostics


warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)


def seed_worker(worker_id):
    """Fixes the seeds of a DataLoader worker process for reproducibility.

    Args:
        worker_id (int): Index of the DataLoader worker process.
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)
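# A minimal sketch (not part of this module) of how ``seed_worker`` is typically
# wired into a DataLoader together with a seeded ``torch.Generator`` so that
# worker-side RNG state is reproducible; ``my_dataset`` and the batch size are
# placeholders:
#
#     g = torch.Generator()
#     g.manual_seed(0)
#     loader = DataLoader(
#         my_dataset,                  # hypothetical dataset
#         batch_size=32,
#         shuffle=False,
#         worker_init_fn=seed_worker,  # re-seeds numpy/random in each worker
#         generator=g,
#     )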
def infer(
    events_bkg,
    jets_bkg,
    constituents_bkg,
    events_sig,
    jets_sig,
    constituents_sig,
    model_path,
    output_path,
    config,
    verbose: bool = False,
):
    """Runs the full inference loop over the background and signal data.

    The input tensors are moved to the selected device, split into data and
    labels, reshaped for convolutional models where required, and batched via
    `torch.utils.data.DataLoader` using `config.batch_size`. The trained model
    is loaded from `model_path`, set to eval mode, and its outputs
    (reconstructions, latent variables and loss components) are saved under
    `output_path`. For reproducibility, the seeds can also be fixed in this
    function.

    Args:
        events_bkg (Tensor): Background event-level data
        jets_bkg (Tensor): Background jet-level data
        constituents_bkg (Tensor): Background constituent-level data
        events_sig (Tensor): Signal event-level data
        jets_sig (Tensor): Signal jet-level data
        constituents_sig (Tensor): Signal constituent-level data
        model_path (string): Path to the trained model
        output_path (string): Path to the output directory
        config (dataClass): Base class selecting user inputs
        verbose (bool): Whether to print progress information

    Returns:
        bool: True once inference has finished and all outputs have been saved
    """
    # Print input shapes
    if verbose:
        print("Events - bkg shape: ", events_bkg.shape)
        print("Jets - bkg shape: ", jets_bkg.shape)
        print("Constituents - bkg shape: ", constituents_bkg.shape)
        print("Events - sig shape: ", events_sig.shape)
        print("Jets - sig shape: ", jets_sig.shape)
        print("Constituents - sig shape: ", constituents_sig.shape)

    # Get the device and move tensors to the device
    device = helper.get_device()
    labeled_data = (
        events_bkg,
        jets_bkg,
        constituents_bkg,
        events_sig,
        jets_sig,
        constituents_sig,
    )
    (
        events_bkg,
        jets_bkg,
        constituents_bkg,
        events_sig,
        jets_sig,
        constituents_sig,
    ) = [x.to(device) for x in labeled_data]

    # Split data and labels
    if verbose:
        print("Splitting data and labels")
    data, labels = helper.data_label_split(labeled_data)

    # Unpack data and labels
    (
        events_bkg,
        jets_bkg,
        constituents_bkg,
        events_sig,
        jets_sig,
        constituents_sig,
    ) = data
    (
        events_bkg_label,
        jets_bkg_label,
        constituents_bkg_label,
        events_sig_label,
        jets_sig_label,
        constituents_sig_label,
    ) = labels

    # Reshape tensors to pass to conv layers
    if "ConvVAE" in config.model_name or "ConvAE" in config.model_name:
        (
            events_bkg,
            jets_bkg,
            constituents_bkg,
            events_sig,
            jets_sig,
            constituents_sig,
        ) = [
            x.unsqueeze(1).float()
            for x in [
                events_bkg,
                jets_bkg,
                constituents_bkg,
                events_sig,
                jets_sig,
                constituents_sig,
            ]
        ]
        data = (
            events_bkg,
            jets_bkg,
            constituents_bkg,
            events_sig,
            jets_sig,
            constituents_sig,
        )

    # Create datasets
    ds = helper.create_datasets(*data, *labels)

    # Concatenate events, jets and constituents respectively
    ds_events = ConcatDataset([ds["events_train"], ds["events_val"]])
    ds_jets = ConcatDataset([ds["jets_train"], ds["jets_val"]])
    ds_constituents = ConcatDataset(
        [ds["constituents_train"], ds["constituents_val"]]
    )
    ds = {
        "events": ds_events,
        "jets": ds_jets,
        "constituents": ds_constituents,
    }

    if verbose:
        # Print input shapes
        print("Events - bkg shape: ", events_bkg.shape)
        print("Jets - bkg shape: ", jets_bkg.shape)
        print("Constituents - bkg shape: ", constituents_bkg.shape)
        print("Events - sig shape: ", events_sig.shape)
        print("Jets - sig shape: ", jets_sig.shape)
        print("Constituents - sig shape: ", constituents_sig.shape)
        # Print label shapes
        print("Events - bkg labels shape: ", events_bkg_label.shape)
        print("Jets - bkg labels shape: ", jets_bkg_label.shape)
        print("Constituents - bkg labels shape: ", constituents_bkg_label.shape)
        print("Events - sig labels shape: ", events_sig_label.shape)
        print("Jets - sig labels shape: ", jets_sig_label.shape)
        print("Constituents - sig labels shape: ", constituents_sig_label.shape)

    # Calculate the input shapes to load the model
    in_shape = helper.calculate_in_shape(data, config)

    # Load the model and set to eval mode for inference
    model = helper.load_model(model_path=model_path, in_shape=in_shape, config=config)
    model.eval()

    if verbose:
        print(f"Model loaded from {model_path}")
        print(f"Model architecture:\n{model}")
        print(f"Device used for inference: {device}")
        print("Inputs and model moved to device")

    # Push the input data into torch DataLoader objects, one per input level
    # (events, jets, constituents).
    print(
        "Loading data into DataLoader and using batch size of ", config.batch_size
    )

    if config.deterministic_algorithm:
        if config.verbose:
            print("Deterministic algorithm is set to True")
        torch.backends.cudnn.deterministic = True
        random.seed(0)
        torch.manual_seed(0)
        np.random.seed(0)
        torch.use_deterministic_algorithms(True)
        g = torch.Generator()
        g.manual_seed(0)

        test_dl_list = [
            DataLoader(
                ds,
                batch_size=config.batch_size,
                shuffle=False,
                worker_init_fn=seed_worker,
                generator=g,
                drop_last=True,
                num_workers=config.parallel_workers,
            )
            for ds in [ds["events"], ds["jets"], ds["constituents"]]
        ]
    else:
        test_dl_list = [
            DataLoader(
                ds,
                batch_size=config.batch_size,
                shuffle=False,
                drop_last=True,
                num_workers=config.parallel_workers,
            )
            for ds in [ds["events"], ds["jets"], ds["constituents"]]
        ]

    # Unpack the DataLoader list
    test_dl_events, test_dl_jets, test_dl_constituents = test_dl_list

    if config.model_name == "pj_ensemble":
        if verbose:
            print("Model is an ensemble model")
    else:
        if config.input_level == "event":
            test_dl = test_dl_events
        elif config.input_level == "jet":
            test_dl = test_dl_jets
        elif config.input_level == "constituent":
            test_dl = test_dl_constituents
        if verbose:
            print(f"Input data is of {config.input_level} level")

    # Select the loss function
    try:
        loss_object = helper.get_loss(config.loss_function)
        loss_fn = loss_object(config=config)
        if verbose:
            print(f"Loss Function: {config.loss_function}")
    except ValueError as e:
        print(e)

    # Output lists
    test_loss_data = []
    reconstructed_data = []
    mu_data = []
    logvar_data = []
    z0_data = []
    zk_data = []
    log_det_jacobian_data = []

    start = time.time()

    # Register hooks for activation extraction
    if config.activation_extraction:
        hooks = model.store_hooks()

    if verbose:
        print("Beginning inference")

    # Inference
    parameters = model.parameters()
    with torch.no_grad():
        for idx, batch in enumerate(tqdm(test_dl)):
            inputs, labels = batch
            out = helper.call_forward(model, inputs)
            recon, mu, logvar, ldj, z0, zk = out

            # Compute the loss
            losses = loss_fn.calculate(
                recon=recon,
                target=inputs,
                mu=mu,
                logvar=logvar,
                parameters=parameters,
                log_det_jacobian=0,
            )
            test_loss_data.append(losses)
            reconstructed_data.append(recon.detach().cpu().numpy())
            mu_data.append(mu.detach().cpu().numpy())
            logvar_data.append(logvar.detach().cpu().numpy())
            log_det_jacobian_data.append(ldj.detach().cpu().numpy())
            z0_data.append(z0.detach().cpu().numpy())
            zk_data.append(zk.detach().cpu().numpy())

    end = time.time()

    # Save activation values
    if config.activation_extraction:
        activations = diagnostics.dict_to_square_matrix(model.get_activations())
        model.detach_hooks(hooks)
        np.save(os.path.join(output_path, "activations.npy"), activations)

    if verbose:
        print(f"Inference took {(end - start) / 60:.3} minutes")

    # Save all the data
    save_dir = os.path.join(output_path, "results")
    np.save(
        os.path.join(save_dir, "reconstructed_data.npy"),
        np.array(reconstructed_data),
    )
    np.save(
        os.path.join(save_dir, "mu_data.npy"),
        np.array(mu_data),
    )
    np.save(
        os.path.join(save_dir, "logvar_data.npy"),
        np.array(logvar_data),
    )
    np.save(
        os.path.join(save_dir, "z0_data.npy"),
        np.array(z0_data),
    )
    np.save(
        os.path.join(save_dir, "zk_data.npy"),
        np.array(zk_data),
    )
    np.save(
        os.path.join(save_dir, "log_det_jacobian_data.npy"),
        np.array(log_det_jacobian_data),
    )
    helper.save_loss_components(
        loss_data=test_loss_data,
        component_names=loss_fn.component_names,
        suffix="test",
        save_dir=save_dir,
    )

    return True