# Copyright 2022 Baler Contributors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import random
import time
import sys
import numpy as np
from tqdm.rich import tqdm
import warnings
from tqdm import TqdmExperimentalWarning
from torch.nn import functional as F
import torch
from torch.utils.data import DataLoader
from ..utils import helper, loss, diagnostics
warnings.filterwarnings("ignore", category=TqdmExperimentalWarning)
def fit(
config,
model,
dataloader,
loss_fn,
reg_param,
optimizer,
):
"""This function trains the model on the train set. It computes the losses and does the backwards propagation, and updates the optimizer as well.
Args:
config (dataClass): Base class selecting user inputs
model (modelObject): The model you wish to train
train_dl (torch.DataLoader): Defines the batched data which the model is trained on
loss (lossObject): Defines the loss function used to train the model
reg_param (float): Determines proportionality constant to balance different components of the loss.
optimizer (torch.optim): Chooses optimizer for gradient descent.
Returns:
list, model object: Training losses, Epoch_loss and trained model
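Example (illustrative; mirrors the call made from `train()` below, assuming `model`, `train_dl`, `loss_fn`, `optimizer` and `config` are already set up):
train_losses, train_epoch_loss, model = fit(
config=config,
model=model,
dataloader=train_dl,
loss_fn=loss_fn,
reg_param=config.reg_param,
optimizer=optimizer,
)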
"""
# Extract model parameters
# Materialize the parameters as a list so the generator is not exhausted after the first batch
parameters = list(model.parameters())
model.train()
running_loss = 0.0
for idx, batch in enumerate(tqdm(dataloader)):
inputs, labels = batch
# Set previous gradients to zero
optimizer.zero_grad()
# Compute the predicted outputs from the input data
out = helper.call_forward(model, inputs)
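# helper.call_forward is expected to return a 6-tuple in the VAE-with-normalizing-flow
# convention: reconstruction, latent mean, latent log-variance, log-determinant of the flow
# Jacobian, and the latent samples before (z0) and after (zk) the flow.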
recon, mu, logvar, ldj, z0, zk = out
# Compute the loss
losses = loss_fn.calculate(
recon=recon,
target=inputs,
mu=mu,
logvar=logvar,
parameters=parameters,
log_det_jacobian=0,
)
loss, *_ = losses
# Compute the loss gradients via backpropagation
loss.backward()
# Update the optimizer
optimizer.step()
# Accumulate the batch loss, detached so the autograd graph is not kept alive across batches
running_loss += loss.detach()
epoch_loss = running_loss / (idx + 1)
print(f"# Training Loss: {epoch_loss:.6f}")
return losses, epoch_loss, model
def validate(config, model, dataloader, loss_fn, reg_param):
"""Function used to validate the training. Not necessary for doing compression, but gives a good indication of wether the model selected is a good fit or not.
Args:
model (modelObject): Defines the model one wants to validate. The model used here is passed directly from `fit()`.
test_dl (torch.DataLoader): Defines the batched data which the model is validated on
model_children (list): List of model parameters
reg_param (float): Determines proportionality constant to balance different components of the loss.
Returns:
float: Validation loss
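Example (illustrative; mirrors the call made from `train()` below):
val_losses, val_epoch_loss = validate(
config=config,
model=model,
dataloader=valid_dl,
loss_fn=loss_fn,
reg_param=config.reg_param,
)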
"""
# Extract model parameters
# Materialize the parameters as a list so the generator is not exhausted after the first batch
parameters = list(model.parameters())
model.eval()
running_loss = 0.0
with torch.no_grad():
for idx, batch in enumerate(tqdm(dataloader)):
inputs, labels = batch
out = helper.call_forward(model, inputs)
recon, mu, logvar, ldj, z0, zk = out
# Compute the loss
losses = loss_fn.calculate(
recon=recon,
target=inputs,
mu=mu,
logvar=logvar,
parameters=parameters,
log_det_jacobian=0,
)
loss, *_ = losses
running_loss += loss
epoch_loss = running_loss / (idx + 1)
print(f"# Validation Loss: {epoch_loss:.6f}")
return losses, epoch_loss
def seed_worker(worker_id):
"""PyTorch implementation to fix the seeds
Args:
worker_id ():
"""
worker_seed = torch.initial_seed() % 2**32
np.random.seed(worker_seed)
random.seed(worker_seed)
def train(
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
output_path,
config,
verbose: bool = False,
):
"""Does the entire training loop by calling the `fit()` and `validate()`. Appart from this, this is the main function where the data is converted
to the correct type for it to be trained, via `torch.Tensor()`. Furthermore, the batching is also done here, based on `config.batch_size`,
and it is the `torch.utils.data.DataLoader` doing the splitting.
Applying either `EarlyStopping` or `LR Scheduler` is also done here, all based on their respective `config` arguments.
For reproducibility, the seeds can also be fixed in this function.
Args:
model (modelObject): The model you wish to train
data (Tuple): Tuple containing the training and validation data
project_path (string): Path to the project directory
config (dataClass): Base class selecting user inputs
Returns:
modelObject: fully trained model ready to perform compression and decompression
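Example (illustrative; assumes the six tensors were produced by the preceding data-handling step and that `config` and `output_path` are provided by the caller):
trained_model = train(
events_train, jets_train, constituents_train,
events_val, jets_val, constituents_val,
output_path=output_path,
config=config,
verbose=True,
)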
"""
if verbose:
print("Events - Training set size: ", events_train.size(0))
print("Events - Validation set size: ", events_val.size(0))
print("Jets - Training set size: ", jets_train.size(0))
print("Jets - Validation set size: ", jets_val.size(0))
print("Constituents - Training set size: ", constituents_train.size(0))
print("Constituents - Validation set size: ", constituents_val.size(0))
# Get the device and move tensors to the device
device = helper.get_device()
labeled_data = (
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
)
(
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
) = [
x.to(device)
for x in labeled_data
]
# Split data and labels
if verbose:
print("Splitting data and labels")
data, labels = helper.data_label_split(labeled_data)
# Unpack the individual training/validation tensors and their labels
(
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
) = data
(
events_train_label,
jets_train_label,
constituents_train_label,
events_val_label,
jets_val_label,
constituents_val_label,
) = labels
# Reshape tensors to pass to conv layers
if "ConvVAE" in config.model_name:
(
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
) = [
x.unsqueeze(1).float()
for x in [events_train, jets_train, constituents_train, events_val, jets_val, constituents_val]
]
data = (
events_train,
jets_train,
constituents_train,
events_val,
jets_val,
constituents_val,
)
# Create datasets
ds = helper.create_datasets(*data, *labels)
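# create_datasets is expected to return a mapping of split names (e.g. "events_train",
# "jets_val") to dataset objects, which is how it is indexed when building the DataLoaders below.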
if verbose:
# Print input shapes
print("Events - Training set shape: ", events_train.shape)
print("Events - Validation set shape: ", events_val.shape)
print("Jets - Training set shape: ", jets_train.shape)
print("Jets - Validation set shape: ", jets_val.shape)
print("Constituents - Training set shape: ", constituents_train.shape)
print("Constituents - Validation set shape: ", constituents_val.shape)
# Print label shapes
print("Events - Training set labels shape: ", events_train_label.shape)
print("Events - Validation set labels shape: ", events_val_label.shape)
print("Jets - Training set labels shape: ", jets_train_label.shape)
print("Jets - Validation set labels shape: ", jets_val_label.shape)
print("Constituents - Training set labels shape: ", constituents_train_label.shape)
print("Constituents - Validation set labels shape: ", constituents_val_label.shape)
# Calculate the input shapes to initialize the model
in_shape = helper.calculate_in_shape(data, config)
# Instantiate and Initialize the model
if verbose:
print(f"Intitalizing Model with Latent Size - {config.latent_space_size}")
model = helper.model_init(in_shape, config)
if verbose:
if config.model_init == "xavier":
print("Model initialized using Xavier initialization")
else:
print("Model initialized using default PyTorch initialization")
print(f"Model architecture:\n{model}")
model = model.to(device)
if verbose:
print(f"Device used for training: {device}")
print(f"Inputs and model moved to device")
# Wrap each dataset in a torch DataLoader, which handles batching (and, when the deterministic
# option is enabled below, reproducible worker seeding).
if verbose:
print(
"Loading data into DataLoader and using batch size of ", config.batch_size
)
if config.deterministic_algorithm:
if verbose:
print("Deterministic algorithm is set to True")
torch.backends.cudnn.deterministic = True
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
torch.use_deterministic_algorithms(True)
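# Note: on CUDA, torch.use_deterministic_algorithms(True) may additionally require the
# CUBLAS_WORKSPACE_CONFIG environment variable (e.g. ":4096:8") to be set; see the PyTorch
# reproducibility notes.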
g = torch.Generator()
g.manual_seed(0)
train_dl_list = [
DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=False,
worker_init_fn=seed_worker,
generator=g,
drop_last=True,
num_workers=config.parallel_workers,
)
for ds in [ds["events_train"], ds["jets_train"], ds["constituents_train"]]
]
valid_dl_list = [
DataLoader(
dataset,
batch_size=config.batch_size,
shuffle=False,
worker_init_fn=seed_worker,
generator=g,
drop_last=True,
num_workers=config.parallel_workers,
)
for ds in [ds["events_val"], ds["jets_val"], ds["constituents_val"]]
]
else:
train_dl_list = [
DataLoader(dataset, batch_size=config.batch_size, shuffle=False, drop_last=True, num_workers=config.parallel_workers)
for dataset in [ds["events_train"], ds["jets_train"], ds["constituents_train"]]
]
valid_dl_list = [
DataLoader(dataset, batch_size=config.batch_size, shuffle=False, drop_last=True, num_workers=config.parallel_workers)
for dataset in [ds["events_val"], ds["jets_val"], ds["constituents_val"]]
]
# Unpacking the DataLoader lists
train_dl_events, train_dl_jets, train_dl_constituents = train_dl_list
val_dl_events, val_dl_jets, val_dl_constituents = valid_dl_list
if config.model_name == "pj_ensemble":
if verbose:
print("Model is an ensemble model")
else:
if config.input_level == "event":
train_dl = train_dl_events
valid_dl = val_dl_events
elif config.input_level == "jet":
train_dl = train_dl_jets
valid_dl = val_dl_jets
elif config.input_level == "constituent":
train_dl = train_dl_constituents
valid_dl = val_dl_constituents
if verbose:
print(f"Input data is of {config.input_level} level")
# Select Loss Function
try:
loss_object = helper.get_loss(config.loss_function)
loss_fn = loss_object(config=config)
if verbose:
print(f"Loss Function: {config.loss_function}")
except ValueError as e:
print(e)
raise
# Select Optimizer
try:
optimizer = helper.get_optimizer(
config.optimizer, model.parameters(), lr=config.lr
)
if verbose:
print(f"Optimizer: {config.optimizer}")
except ValueError as e:
print(e)
raise
# Activate early stopping
if config.early_stopping:
if verbose:
print(
"Early stopping is activated with patience of ",
config.early_stopping_patience,
)
early_stopper = helper.EarlyStopping(
patience=config.early_stopping_patience, min_delta=config.min_delta
) # Changes to patience & min_delta can be made in configs
# Activate LR Scheduler
if config.lr_scheduler:
if verbose:
print(
"Learning rate scheduler is activated with patience of ",
config.lr_scheduler_patience,
)
lr_scheduler = helper.LRScheduler(
optimizer=optimizer, patience=config.lr_scheduler_patience
)
# Training and Validation of the model
train_loss_data = []
val_loss_data = []
train_loss = []
val_loss = []
start = time.time()
# Registering hooks for activation extraction
if config.activation_extraction:
hooks = model.store_hooks()
if verbose:
print(f"Beginning training for {config.epochs} epochs")
for epoch in range(config.epochs):
print(f"Epoch {epoch + 1} / {config.epochs}")
train_losses, train_epoch_loss, model = fit(
config=config,
model=model,
dataloader=train_dl,
loss_fn=loss_fn,
reg_param=config.reg_param,
optimizer=optimizer,
)
train_loss.append(train_epoch_loss.detach().cpu().numpy())
train_loss_data.append(train_losses)
# Run validation only when a validation split exists (train_size < 1)
if config.train_size < 1:
val_losses, val_epoch_loss = validate(
config=config,
model=model,
dataloader=valid_dl,
loss_fn=loss_fn,
reg_param=config.reg_param,
)
val_loss.append(val_epoch_loss.detach().cpu().numpy())
val_loss_data.append(val_losses)
else:
val_epoch_loss = train_epoch_loss
val_losses = train_losses
val_loss.append(val_epoch_loss.detach().cpu().numpy())
val_loss_data.append(val_losses)
# Implementing LR Scheduler
if config.lr_scheduler:
lr_scheduler(val_epoch_loss)
# Save the model every N epochs, where N is given by config.intermittent_saving_patience:
if config.intermittent_model_saving:
if epoch % config.intermittent_saving_patience == 0:
path = os.path.join(output_path, "models", f"model_{epoch}.pt")
helper.model_saver(model, path)
# Implementing Early Stopping
if config.early_stopping:
early_stopper(val_epoch_loss)
if early_stopper.early_stop:
if verbose:
print("Early stopping activated at epoch ", epoch)
break
end = time.time()
# Saving activations values
if config.activation_extraction:
activations = diagnostics.dict_to_square_matrix(model.get_activations())
model.detach_hooks(hooks)
np.save(os.path.join(output_path, "activations.npy"), activations)
if verbose:
print(f"Training the model took {(end - start) / 60:.3} minutes")
# Save loss data
save_dir = os.path.join(output_path, "results")
np.save(
os.path.join(save_dir, "train_epoch_loss_data.npy"),
np.array(train_loss),
)
np.save(
os.path.join(save_dir, "val_epoch_loss_data.npy"),
np.array(val_loss),
)
helper.save_loss_components(loss_data=train_loss_data, component_names=loss_fn.component_names, suffix="train", save_dir=save_dir)
helper.save_loss_components(loss_data=val_loss_data, component_names=loss_fn.component_names, suffix="val", save_dir=save_dir)
if verbose:
print("Loss data saved to path: ", save_dir)
return model