Train

This module defines the training process for the neural network model. The GraphForecaster object in forecaster.py is responsible for the forward pass of the model itself. The key functions in the forecaster that users may want to adapt to their own applications are:

  • advance_input, which defines how the model iterates forward in forecast time

  • _step, where the forward pass of the model happens both during training and validation
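As a rough illustration of what advancing the input means in an autoregressive rollout, the sketch below slides a fixed-length input window forward by one step and appends the latest prediction. It uses plain Python lists instead of tensors. Both `advance_window` and `toy_model` are hypothetical helpers, not the actual advance_input or _step implementations, which operate on tensors and are more involved.

```python
def advance_window(window: list[float], prediction: float) -> list[float]:
    """Drop the oldest time step and append the newest prediction."""
    return window[1:] + [prediction]


def toy_model(window: list[float]) -> float:
    """Stand-in for a forward pass: predicts the mean of the window."""
    return sum(window) / len(window)


window = [1.0, 3.0]  # two input time steps
predictions = []
for _ in range(3):  # three rollout steps
    y_pred = toy_model(window)
    predictions.append(y_pred)
    # The prediction becomes part of the next input
    window = advance_window(window, y_pred)
```

After three steps the window contains only model predictions, which is why errors can accumulate over a long rollout.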

AnemoiTrainer in train.py is the object from which the training of the model is controlled. It also contains functions that enable the user to profile the training of the model (profiler.py).

class anemoi.training.train.forecaster.GraphForecaster(*, config: DictConfig, graph_data: HeteroData, statistics: dict, data_indices: IndexCollection, metadata: dict, supporting_arrays: dict)

Bases: LightningModule

Graph neural network forecaster for PyTorch Lightning.

forward(x: Tensor) Tensor

Same as torch.nn.Module.forward().

Parameters:
  • *args – Whatever you decide to pass into the forward method.

  • **kwargs – Keyword arguments are also possible.

Returns:

Your model’s output

static get_loss_function(config: DictConfig, scalars: dict[str, tuple[int | tuple[int, ...] | Tensor]] | None = None, **kwargs) BaseWeightedLoss | ModuleList

Get loss functions from config.

Can be ModuleList if multiple losses are specified.

Parameters:
  • config (DictConfig) – Loss function configuration, should include scalars if scalars are to be added to the loss function.

  • scalars (Union[dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]], None], optional) –

    Scalars which can be added to the loss function, by default None. If a scalar is to be added to the loss, ensure it is listed under scalars in the loss config.

    For example, if scalars: ['variable'] is set in the config and variable is a key in scalars, then variable will be added to the scalars of the loss function.

  • kwargs (Any) – Additional arguments to pass to the loss function

Returns:

Loss function, or list of metrics

Return type:

Union[BaseWeightedLoss, torch.nn.ModuleList]

Raises:
  • TypeError – If not a subclass of BaseWeightedLoss

  • ValueError – If scalar is not found in valid scalars
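The lookup described above, where names listed in the config are matched against the available scalars and an unknown name raises a ValueError, could be sketched as follows. `select_scalars` and the example entries are hypothetical and only mirror the documented behaviour; they are not the code used by get_loss_function.

```python
def select_scalars(requested: list[str], available: dict[str, tuple]) -> dict[str, tuple]:
    """Return the requested scalars, failing loudly on unknown names."""
    selected = {}
    for name in requested:
        if name not in available:
            msg = f"Scalar {name!r} not found in valid scalars: {sorted(available)}"
            raise ValueError(msg)
        selected[name] = available[name]
    return selected


# Hypothetical scalars, each stored as (dimension index, weights)
available = {"variable": (-1, [1.0, 0.5]), "mask": ((-2, -1), [1.0])}

# Corresponds to `scalars: ['variable']` in the loss config
chosen = select_scalars(["variable"], available)
```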

training_weights_for_imputed_variables(batch: Tensor) None

Update the loss weights mask for imputed variables.

rollout_step(batch: Tensor, rollout: int | None = None, training_mode: bool = True, validation_mode: bool = False) Generator[tuple[Tensor | None, dict, list], None, None]

Rollout step for the forecaster.

Will run pre_processors on batch, but not post_processors on predictions.

Parameters:
  • batch (torch.Tensor) – Batch to use for rollout

  • rollout (Optional[int], optional) – Number of rollout steps, by default None. If None, self.rollout is used.

  • training_mode (bool, optional) – Whether in training mode and to calculate the loss, by default True. If False, the loss will be None.

  • validation_mode (bool, optional) – Whether in validation mode and to calculate validation metrics, by default False. If False, the metrics will be empty.

Yields:

Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] – Loss value, metrics, and predictions (per step)

Returns:

None

Return type:

None
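Because rollout_step is a generator, the caller receives one (loss, metrics, predictions) tuple per rollout step and decides how to combine them. The sketch below shows that consumption pattern with a toy generator. `rollout_sketch`, its arithmetic, and the metric names are hypothetical and stand in for the real forward pass and loss.

```python
from collections.abc import Generator
from typing import Optional


def rollout_sketch(
    steps: int, training_mode: bool = True, validation_mode: bool = False
) -> Generator[tuple[Optional[float], dict, list], None, None]:
    """Yield (loss, metrics, predictions) once per rollout step."""
    state = 1.0
    for step in range(steps):
        state = state * 0.5  # stand-in for one forward pass
        loss = abs(state) if training_mode else None
        metrics = {f"metric_step_{step + 1}": state} if validation_mode else {}
        yield loss, metrics, [state]


# Accumulate the per-step results, then average the loss over the rollout
total_loss = 0.0
all_metrics = {}
for loss, metrics, preds in rollout_sketch(steps=2, validation_mode=True):
    total_loss += loss
    all_metrics.update(metrics)
mean_loss = total_loss / 2
```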

allgather_batch(batch: Tensor) Tensor

Allgather the batch-shards across the reader group.

Parameters:

batch (torch.Tensor) – Batch-shard of current reader rank

Returns:

Allgathered (full) batch

Return type:

torch.Tensor
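Conceptually, an allgather leaves every participating rank holding the concatenation of all ranks' shards in rank order. The sketch below simulates this with plain lists and no actual communication. `simulated_allgather` is hypothetical, and the dimension along which the real method concatenates is not shown here.

```python
def simulated_allgather(shards_by_rank: list[list[int]]) -> list[list[int]]:
    """Give every rank the concatenation of all shards, in rank order."""
    full = [value for shard in shards_by_rank for value in shard]
    # Each rank ends up with its own copy of the full batch
    return [list(full) for _ in shards_by_rank]


# Two reader ranks, each holding a different shard of the same batch
shards = [[0, 1, 2], [3, 4, 5]]
gathered = simulated_allgather(shards)
```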

calculate_val_metrics(y_pred: Tensor, y: Tensor, rollout_step: int) tuple[dict, list[Tensor]]

Calculate metrics on the validation output.

Parameters:
  • y_pred (torch.Tensor) – Predicted ensemble

  • y (torch.Tensor) – Ground truth (target).

  • rollout_step (int) – Rollout step

Returns:

validation metrics and predictions

Return type:

tuple[dict, list[torch.Tensor]]

training_step(batch: Tensor, batch_idx: int) Tensor

Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.

Parameters:
  • batch – The output of your data iterable, normally a DataLoader.

  • batch_idx – The index of this batch.

  • dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)

Returns:

  • Tensor - The loss tensor

  • dict - A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.

  • None - In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.

In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.

Example:

def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss

To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:

def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()

Note

When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
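The effect of this normalization can be checked with plain arithmetic: dividing each micro-batch loss by accumulate_grad_batches before accumulating gives the mean loss over the accumulated batches. The loss values below are made up for illustration.

```python
accumulate_grad_batches = 4
batch_losses = [2.0, 4.0, 6.0, 8.0]  # made-up losses of four micro-batches

# Each loss is scaled before its gradients are accumulated ...
scaled = [loss / accumulate_grad_batches for loss in batch_losses]
# ... so the accumulated total equals the mean loss over the micro-batches.
accumulated = sum(scaled)
mean_loss = sum(batch_losses) / len(batch_losses)
```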

lr_scheduler_step(scheduler: CosineLRScheduler, metric: None = None) None

Step the learning rate scheduler. This hook is called by PyTorch Lightning.

Parameters:
  • scheduler (CosineLRScheduler) – Learning rate scheduler object.

  • metric (Optional[Any]) – Metric object for e.g. ReduceLROnPlateau. Default is None.
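For intuition about what repeatedly stepping a cosine scheduler produces, the sketch below computes a generic cosine decay. It is not the CosineLRScheduler implementation: `cosine_lr` and its parameters are hypothetical, and warmup and restarts are left out.

```python
import math


def cosine_lr(step: int, total_steps: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Cosine decay from lr_max at step 0 to lr_min at total_steps."""
    progress = min(step, total_steps) / total_steps
    return lr_min + 0.5 * (lr_max - lr_min) * (1.0 + math.cos(math.pi * progress))


# Learning rate at each of the five scheduler steps 0..4
schedule = [cosine_lr(step, total_steps=4, lr_max=1.0) for step in range(5)]
```

The rate starts at lr_max, passes the midpoint between lr_max and lr_min halfway through, and reaches lr_min at the final step.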

on_train_epoch_end() None

Called in the training loop at the very end of the epoch.

To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()

validation_step(batch: Tensor, batch_idx: int) None

Calculate the loss over a validation batch using the training loss function.

Parameters:
  • batch (torch.Tensor) – Validation batch

  • batch_idx (int) – Batch index

Return type:

None

configure_optimizers() tuple[list[Optimizer], list[dict]]

Choose what optimizers and learning-rate schedulers to use in your optimization. Normally you’d need one. But in the case of GANs or similar you might have multiple. Optimization with multiple optimizers only works in the manual optimization mode.

Returns:

Any of the following options:

  • Single optimizer.

  • List or Tuple of optimizers.

  • Two lists - The first list has multiple optimizers, and the second has multiple LR schedulers (or multiple lr_scheduler_config).

  • Dictionary, with an "optimizer" key, and (optionally) a "lr_scheduler" key whose value is a single LR scheduler or lr_scheduler_config.

  • None - Fit will run without any optimizer.

The lr_scheduler_config is a dictionary which contains the scheduler and its associated configuration. The default configuration is shown below.

lr_scheduler_config = {
    # REQUIRED: The scheduler instance
    "scheduler": lr_scheduler,
    # The unit of the scheduler's step size, could also be 'step'.
    # 'epoch' updates the scheduler on epoch end whereas 'step'
    # updates it after an optimizer update.
    "interval": "epoch",
    # How many epochs/steps should pass between calls to
    # `scheduler.step()`. 1 corresponds to updating the learning
    # rate after every epoch/step.
    "frequency": 1,
    # Metric to monitor for schedulers like `ReduceLROnPlateau`
    "monitor": "val_loss",
    # If set to `True`, will enforce that the value specified in 'monitor'
    # is available when the scheduler is updated, thus stopping
    # training if not found. If set to `False`, it will only produce a warning
    "strict": True,
    # If using the `LearningRateMonitor` callback to monitor the
    # learning rate progress, this keyword can be used to specify
    # a custom logged name
    "name": None,
}

When there are schedulers in which the .step() method is conditioned on a value, such as the torch.optim.lr_scheduler.ReduceLROnPlateau scheduler, Lightning requires that the lr_scheduler_config contains the keyword "monitor" set to the metric name that the scheduler should be conditioned on.

A metric can be made available to monitor by logging it with self.log('metric_to_track', metric_val) in your LightningModule.
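To see why such a scheduler needs a monitored value, the sketch below mimics plateau logic in plain Python. `PlateauSketch` is a hypothetical stand-in, not torch.optim.lr_scheduler.ReduceLROnPlateau; the point is that its `step` method cannot decide anything without the metric it is passed.

```python
class PlateauSketch:
    """Reduce the learning rate when the monitored metric stops improving."""

    def __init__(self, lr: float, factor: float = 0.5, patience: int = 1) -> None:
        self.lr = lr
        self.factor = factor
        self.patience = patience
        self.best = float("inf")
        self.bad_epochs = 0

    def step(self, metric: float) -> None:
        """Update the learning rate based on the latest monitored value."""
        if metric < self.best:
            self.best = metric
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        # Reduce only after more than `patience` epochs without improvement
        if self.bad_epochs > self.patience:
            self.lr *= self.factor
            self.bad_epochs = 0


scheduler = PlateauSketch(lr=1.0)
for val_loss in [0.9, 0.8, 0.8, 0.8]:  # made-up validation losses
    scheduler.step(val_loss)
```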

Note

Some things to know:

  • Lightning calls .backward() and .step() automatically in case of automatic optimization.

  • If a learning rate scheduler is specified in configure_optimizers() with key "interval" (default “epoch”) in the scheduler configuration, Lightning will call the scheduler’s .step() method automatically in case of automatic optimization.

  • If you use 16-bit precision (precision=16), Lightning will automatically handle the optimizer.

  • If you use torch.optim.LBFGS, Lightning handles the closure function automatically for you.

  • If you use multiple optimizers, you will have to switch to ‘manual optimization’ mode and step them yourself.

  • If you need to control how often the optimizer steps, override the optimizer_step() hook.

class anemoi.training.train.train.AnemoiTrainer(config: DictConfig)

Bases: object

Utility class for training the model.

property datamodule: AnemoiDatasetsDataModule

DataModule instance and DataSets.

property data_indices: dict

Returns a dictionary of data indices.

This is used to slice the data.

property initial_seed: int

Initial seed for the RNG.

This sets the same initial seed for all ranks. Ranks are re-seeded in the strategy to account for model communication groups.
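One way to picture this re-seeding is that all ranks start from the same initial seed and then derive a seed that is shared within a model communication group but differs between groups. The sketch below illustrates that idea only. `group_seed`, the contiguous rank-to-group mapping, and the multiplication used to derive the seed are assumptions made for illustration, not a description of the strategy's code.

```python
def group_seed(initial_seed: int, global_rank: int, group_size: int) -> int:
    """Derive a seed shared by all ranks of one model communication group."""
    # Assumed mapping: consecutive ranks form one group
    group_id = global_rank // group_size
    # Assumed derivation, for illustration only
    return initial_seed * (group_id + 1)


# Four ranks split into two groups of two
seeds = [group_seed(initial_seed=42, global_rank=rank, group_size=2) for rank in range(4)]
```

Ranks that jointly hold one model instance draw the same random numbers, while different groups draw different ones.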

property graph_data: HeteroData

Graph data.

Creates the graph in all workers.

property model: GraphForecaster

Provide the model instance.

property run_id: str

Unique identifier for the current run.

property wandb_logger: WandbLogger

WandB logger.

property mlflow_logger: MLFlowLogger

MLflow logger.

property tensorboard_logger: TensorBoardLogger

TensorBoard logger.

property last_checkpoint: str | None

Path to the last checkpoint.

property metadata: dict

Metadata and provenance information.

property profiler: PyTorchProfiler | None

Returns a PyTorch profiler object, if profiling is enabled.

property strategy: DDPGroupStrategy

Training strategy.

train() None

Training entry point.