Train

The GraphForecaster and AnemoiTrainer classes define the training process for the neural network model. The GraphForecaster is the LightningModule that implements the model task, while AnemoiTrainer controls the training run and calls the training function.

Forecaster

The different model tasks are reflected in different forecasters:

  1. Deterministic Forecasting (GraphForecaster)

  2. Ensemble Forecasting (GraphEnsForecaster)

  3. Time Interpolation (GraphInterpolator)

The GraphForecaster object in forecaster.py is responsible for the forward pass of the model itself. The key functions in the forecaster that users may want to adapt to their own applications are (see the sketch after this list):

  • advance_input, which defines how the model iterates forward in forecast time

  • _step, where the forward pass of the model happens both during training and validation
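
As an illustration, a custom task can be built by subclassing GraphForecaster and overriding these hooks. This is a minimal sketch only: the signature assumed for advance_input and the way it is overridden here are assumptions that may differ between anemoi-training versions.

from anemoi.training.train.forecaster.forecaster import GraphForecaster


class MyForecaster(GraphForecaster):
    """Hypothetical subclass customising how the input window is advanced."""

    def advance_input(self, x, y_pred, batch, rollout_step):
        # Assumed signature: delegate to the parent implementation, which
        # rolls the input window forward in forecast time using the latest
        # prediction, then apply any custom logic on top.
        x = super().advance_input(x, y_pred, batch, rollout_step)
        return x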

AnemoiTrainer in train.py is the object that controls the training of the model. It also provides functions that let the user profile the training run (profiler.py).

class anemoi.training.train.forecaster.forecaster.GraphForecaster(*, config: BaseSchema, graph_data: HeteroData, truncation_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict, supporting_arrays: dict)

Bases: LightningModule

Graph neural network forecaster for PyTorch Lightning.

forward(x: Tensor) → Tensor

Run the model forward pass (same role as torch.nn.Module.forward()).

Parameters:
  • x (torch.Tensor) – Input tensor.

Returns:

The model's output tensor

static get_loss_function(config: DictConfig, scalars: dict[str, tuple[int | tuple[int, ...] | Tensor]] | None = None, **kwargs) → BaseWeightedLoss | ModuleList

Get loss functions from config.

Can be ModuleList if multiple losses are specified.

Parameters:
  • config (DictConfig) – Loss function configuration, should include scalars if scalars are to be added to the loss function.

  • scalars (dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]] | None) – Scalars which can be added to the loss function. Defaults to None. If a scalar is to be applied to the loss, it must be listed under scalars in the loss config. For instance, if scalars: ['variable'] is set in the config and 'variable' is a key of this dict, 'variable' will be added to the scalars of the loss function (see the example below).

  • kwargs (Any) – Additional arguments to pass to the loss function

Returns:

The loss function to use for training

Return type:

BaseWeightedLoss | torch.nn.ModuleList

Raises:
  • TypeError – If the configured loss is not a subclass of BaseWeightedLoss

  • ValueError – If a requested scalar is not found among the valid scalars
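
For illustration, the sketch below builds a loss from a config. The _target_ path, the option names, and the scalar values here are assumptions for this example and may differ between anemoi-training versions and configs.

import torch
from omegaconf import OmegaConf

from anemoi.training.train.forecaster.forecaster import GraphForecaster

# Hypothetical loss config (the _target_ path and keys are assumed).
loss_config = OmegaConf.create(
    {
        "_target_": "anemoi.training.losses.mse.WeightedMSELoss",
        "scalars": ["variable"],
    }
)

# Per the type hint above, each scalar is a (dimension(s), tensor) pair.
variable_weights = torch.ones(98)  # assumed: one weight per output variable
scalars = {"variable": (-1, variable_weights)}

loss = GraphForecaster.get_loss_function(loss_config, scalars=scalars)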

training_weights_for_imputed_variables(batch: Tensor) None

Update the loss weights mask for imputed variables.

rollout_step(batch: torch.Tensor, rollout: int | None = None, training_mode: bool = True, validation_mode: bool = False) → Generator[tuple[torch.Tensor | None, dict, list], None, None]

Rollout step for the forecaster.

Will run pre_processors on batch, but not post_processors on predictions.

Parameters:
  • batch (torch.Tensor) – Batch to use for rollout

  • rollout (Optional[int], optional) – Number of rollout steps; by default None, in which case self.rollout is used.

  • training_mode (bool, optional) – Whether to run in training mode and calculate the loss, by default True. If False, the yielded loss is None.

  • validation_mode (bool, optional) – Whether to run in validation mode and calculate validation metrics, by default False. If False, the yielded metrics dict is empty.

Yields:

Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] – Loss value, metrics, and predictions (per step)
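
A minimal sketch of consuming the generator, e.g. from a custom training_step; self is assumed to be a GraphForecaster, and the override shown here is illustrative, not the library's own implementation.

def training_step(self, batch, batch_idx):
    # Sum the loss over all rollout steps (illustrative only).
    total_loss = None
    for loss, metrics, _preds in self.rollout_step(batch, training_mode=True):
        total_loss = loss if total_loss is None else total_loss + loss
    return total_loss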

allgather_batch(batch: Tensor) → Tensor

Allgather the batch-shards across the reader group.

Parameters:

batch (torch.Tensor) – Batch-shard of current reader rank

Returns:

Allgathered (full) batch

Return type:

torch.Tensor

calculate_val_metrics(y_pred: Tensor, y: Tensor, rollout_step: int) → tuple[dict, list[Tensor]]

Calculate metrics on the validation output.

Parameters:
  • y_pred (torch.Tensor) – Predicted ensemble

  • y (torch.Tensor) – Ground truth (target).

  • rollout_step (int) – Rollout step

Returns:

val_metrics, preds – Validation metrics and predictions

Return type:

tuple[dict, list[torch.Tensor]]

training_step(batch: Tensor, batch_idx: int) → Tensor

Here you compute and return the training loss and some additional metrics, e.g. for the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.

lr_scheduler_step(scheduler: CosineLRScheduler, metric: None = None) → None

Step the learning rate scheduler; called by PyTorch Lightning.

Parameters:
  • scheduler (CosineLRScheduler) – Learning rate scheduler object.

  • metric (Optional[Any]) – Metric object, e.g. for ReduceLROnPlateau. Default is None.

on_train_epoch_end() → None

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

validation_step(batch: Tensor, batch_idx: int) → None

Calculate the loss over a validation batch using the training loss function.

Parameters:
  • batch (torch.Tensor) – Validation batch

  • batch_idx (int) – Batch index

configure_optimizers() → tuple[list[Optimizer], list[dict]]

Configure the optimizers and learning rate scheduler.

Returns:

List of optimizers and list of dictionaries containing the learning rate scheduler

Return type:

tuple[list[torch.optim.Optimizer], list[dict]]

class anemoi.training.train.forecaster.ensforecaster.GraphEnsForecaster(*, config: DictConfig, graph_data: HeteroData, truncation_data: dict, statistics: dict, data_indices: dict, metadata: dict, supporting_arrays: dict)

Bases: GraphForecaster

Graph neural network forecaster for ensembles for PyTorch Lightning.

forward(x: Tensor, fcstep: int) → Tensor

Run the ensemble model forward pass (same role as torch.nn.Module.forward()).

Parameters:
  • x (torch.Tensor) – Input tensor.

  • fcstep (int) – Forecast step.

Returns:

The model's output tensor

gather_and_compute_loss(y_pred: torch.Tensor, y: torch.Tensor, loss: torch.nn.Module, nens_per_device: int, ens_comm_group_size: int, ens_comm_group: ProcessGroup, model_comm_group: ProcessGroup, return_pred_ens: bool = False) → tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]

Gather the ensemble members from all devices in my group.

Eliminate duplicates (if any) and compute the loss.

Parameters:
  • y_pred (torch.Tensor) – Predicted state tensor, calculated on self.device

  • y (torch.Tensor) – Ground truth

  • loss (torch.nn.Module) – Loss function

  • nens_per_device (int) – Number of ensemble members per device

  • ens_comm_group_size (int) – Size of the ensemble communication group

  • ens_comm_group (ProcessGroup) – Ensemble communication process group

  • model_comm_group (ProcessGroup) – Model communication process group

  • return_pred_ens (bool) – Validation flag: if True, return the predicted ensemble (post-gather)

Returns:

  • loss_inc – Loss

  • y_pred_ens – Gathered ensemble predictions, returned only in validation mode

rollout_step(batch: torch.Tensor, rollout: int | None = None, training_mode: bool = True, validation_mode: bool = False) → Generator[tuple[torch.Tensor | None, dict, list], None, None]

Rollout step for the forecaster.

Will run pre_processors on batch, but not post_processors on predictions.

Parameters:
  • batch (torch.Tensor) – Batch to use for rollout

  • rollout (Optional[int], optional) – Number of rollout steps; by default None, in which case self.rollout is used.

  • training_mode (bool, optional) – Whether to run in training mode and calculate the loss, by default True. If False, the yielded loss is None.

  • validation_mode (bool, optional) – Whether to run in validation mode and calculate validation metrics, by default False. If False, the yielded metrics dict is empty.

Yields:

Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] – Loss value, metrics, and predictions (per step)

Returns:

None

Return type:

None

training_step(batch: tuple[Tensor, ...], batch_idx: int) → Tensor | dict

Run one training step.

Parameters:
  • batch (tuple) – Batch data; a tuple of length 1 or 2. batch[0]: analysis, shape (bs, multi_step + rollout, nvar, latlon). batch[1] (optional, with ensemble): EDA perturbations, shape (multi_step, nens_per_device, nvar, latlon). See the sketch after this entry.

  • batch_idx (int) – Training batch index

Returns:

train_loss – Training loss

Return type:

torch.Tensor
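
For orientation, the batch layout above can be sketched with dummy tensors; the dimension sizes below are arbitrary placeholders, not values from the library.

import torch

# Arbitrary placeholder sizes, named after the shape description above.
bs, multi_step, rollout = 2, 2, 1
nens_per_device, nvar, latlon = 4, 80, 1024

analysis = torch.randn(bs, multi_step + rollout, nvar, latlon)
eda_perturbations = torch.randn(multi_step, nens_per_device, nvar, latlon)

batch = (analysis, eda_perturbations)  # or (analysis,) without perturbations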

validation_step(batch: tuple[Tensor, ...], batch_idx: int) → tuple[Tensor, Tensor]

Perform a validation step.

Parameters:
  • batch (tuple) – Batch data; a tuple of length 1 or 2. batch[0]: analysis, shape (bs, multi_step + rollout, nvar, latlon). batch[1] (optional): EDA perturbations, shape (nens_per_device, multi_step, nvar, latlon).

  • batch_idx (int) – Validation batch index

Returns:

Tuple containing the validation loss, the predictions, and the ensemble initial conditions

Return type:

tuple[torch.Tensor, torch.Tensor]

class anemoi.training.train.forecaster.interpolator.GraphInterpolator(*, config: DictConfig, graph_data: HeteroData, statistics: dict, data_indices: IndexCollection, metadata: dict, supporting_arrays: dict)

Bases: GraphForecaster

Graph neural network interpolator for PyTorch Lightning.

forward(x: Tensor, target_forcing: Tensor) → Tensor

Run the interpolator forward pass (same role as torch.nn.Module.forward()).

Parameters:
  • x (torch.Tensor) – Input tensor.

  • target_forcing (torch.Tensor) – Forcing for the target interpolation time(s).

Returns:

The model's output tensor

Trainer

The AnemoiTrainer object in train.py is responsible for calling the training function.

class anemoi.training.train.train.AnemoiTrainer(config: DictConfig)

Bases: object

Utility class for training the model.

property datamodule: Any

DataModule instance and DataSets.

property data_indices: dict

Returns a dictionary of data indices.

This is used to slice the data.

property initial_seed: int

Initial seed for the RNG.

This sets the same initial seed for all ranks. Ranks are re-seeded in the strategy to account for model communication groups.

property graph_data: HeteroData

Graph data.

Creates the graph in all workers.

property truncation_data: dict

Truncation data.

Loads truncation data.

property model: LightningModule

Provide the model instance.

property run_id: str

Unique identifier for the current run.

property wandb_logger: WandbLogger

WandB logger.

property mlflow_logger: MLFlowLogger

MLflow logger.

property tensorboard_logger: TensorBoardLogger

TensorBoard logger.

property last_checkpoint: str | None

Path to the last checkpoint.

property metadata: dict

Metadata and provenance information.

property profiler: PyTorchProfiler | None

Returns a PyTorch profiler object, if profiling is enabled.

train() → None

Training entry point.
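
A minimal sketch of driving a training run programmatically; it assumes a fully composed Hydra/OmegaConf configuration, whose exact schema depends on your anemoi-training version (runs are typically launched through the package's command-line entry point instead).

from omegaconf import DictConfig

from anemoi.training.train.train import AnemoiTrainer


def run_training(config: DictConfig) -> None:
    # Instantiate the trainer with a composed config and start training.
    trainer = AnemoiTrainer(config)
    trainer.train()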