Train
The GraphForecaster and AnemoiTrainer define the training process for the neural network model. The GraphForecaster is the LightningModule that defines the model task, while the AnemoiTrainer module calls the training function.
Forecaster
The different model tasks are reflected in different forecasters:
- Deterministic Forecasting (GraphForecaster)
- Ensemble Forecasting (GraphEnsForecaster)
- Time Interpolation (GraphInterpolator)
The GraphForecaster object in forecaster.py is responsible for the forward pass of the model itself. The key functions in the forecaster that users may want to adapt to their own applications are:
- advance_input, which defines how the model iterates forward in forecast time
- _step, where the forward pass of the model happens both during training and validation
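As a sketch of how these hooks might be adapted, the following subclass overrides both of them. This is a minimal sketch only: the method signatures and the attributes used inside them are assumptions for illustration, not the documented API.

from anemoi.training.train.forecaster.forecaster import GraphForecaster


class MyForecaster(GraphForecaster):
    """Sketch of a task-specific forecaster (signatures are assumed)."""

    def advance_input(self, x, y_pred, batch, rollout_step):
        # Assumed hook: shift the input window one step forward in forecast
        # time, inserting the latest prediction before the next rollout step.
        return super().advance_input(x, y_pred, batch, rollout_step)

    def _step(self, batch, batch_idx, validation_mode=False):
        # Assumed hook: the forward pass shared by training and validation.
        return super()._step(batch, batch_idx, validation_mode=validation_mode)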
AnemoiTrainer in train.py is the object from which the training of the model is controlled. It also contains functions that enable the user to profile the training of the model (profiler.py).
- class anemoi.training.train.forecaster.forecaster.GraphForecaster(*, config: BaseSchema, graph_data: HeteroData, truncation_data: dict, statistics: dict, data_indices: IndexCollection, metadata: dict, supporting_arrays: dict)
Bases:
LightningModule
Graph neural network forecaster for PyTorch Lightning.
- forward(x: Tensor) → Tensor
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- static get_loss_function(config: DictConfig, scalars: dict[str, tuple[int | tuple[int, ...] | Tensor]] | None = None, **kwargs) → BaseWeightedLoss | ModuleList
Get loss functions from config.
Can be ModuleList if multiple losses are specified.
- Parameters:
config (DictConfig) – Loss function configuration, should include scalars if scalars are to be added to the loss function.
scalars (dict[str, tuple[Union[int, tuple[int, ...], torch.Tensor]]] | None) – Scalars which can be added to the loss function. Defaults to None. If a scalar is to be added to the loss, ensure it is listed under ‘scalars’ in the loss config. For instance, if ‘scalars: [‘variable’]’ is set in the loss config and ‘variable’ is present in this dict, ‘variable’ will be added as a scalar to the loss function.
kwargs (Any) – Additional arguments to pass to the loss function
- Returns:
The loss function to use for training
- Return type:
BaseWeightedLoss | torch.nn.ModuleList
- Raises:
TypeError – If not a subclass of ‘BaseWeightedLoss’
ValueError – If scalar is not found in valid scalars
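To make the interplay between the loss config and the scalars dict concrete, here is a hypothetical sketch. The _target_ path and the scalar contents are illustrative assumptions; only the get_loss_function signature and the ‘scalars’ config key are taken from the documentation above.

import torch
from omegaconf import OmegaConf

from anemoi.training.train.forecaster.forecaster import GraphForecaster

# Hypothetical loss configuration; the _target_ path is an assumption.
loss_config = OmegaConf.create(
    {
        "_target_": "anemoi.training.losses.mse.WeightedMSELoss",
        "scalars": ["variable"],
    }
)

# Scalars are keyed by name; each value pairs the dimension(s) the scalar
# applies to with the scaling tensor (illustrative values).
scalars = {"variable": (-1, torch.ones(101))}

loss_fn = GraphForecaster.get_loss_function(loss_config, scalars=scalars)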
- training_weights_for_imputed_variables(batch: Tensor) → None
Update the loss weights mask for imputed variables.
- rollout_step(batch: torch.Tensor, rollout: int | None = None, training_mode: bool = True, validation_mode: bool = False) → Generator[tuple[torch.Tensor | None, dict, list], None, None]
Rollout step for the forecaster.
Will run pre_processors on batch, but not post_processors on predictions.
- Parameters:
batch (torch.Tensor) – Batch to use for rollout
rollout (Optional[int], optional) – Number of times to rollout for, by default None If None, will use self.rollout
training_mode (bool, optional) – Whether in training mode and to calculate the loss, by default True If False, loss will be None
validation_mode (bool, optional) – Whether in validation mode, and to calculate validation metrics, by default False If False, metrics will be empty
- Yields:
Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] – Loss value, metrics, and predictions (per step)
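A minimal sketch of consuming this generator from inside a forecaster method, assuming the usual training-step pattern; the self.rollout attribute and the loss accumulation are assumptions, only the documented signature and yield structure are taken from above.

import torch

# Assumed context: inside a GraphForecaster method, with `batch` already available.
train_loss = torch.zeros(1, device=self.device)
n_steps = 0
for loss, metrics, predictions in self.rollout_step(
    batch,
    rollout=self.rollout,   # assumed attribute; rollout=None would fall back to it anyway
    training_mode=True,     # loss is computed at every step
    validation_mode=False,  # no validation metrics collected
):
    train_loss = train_loss + loss
    n_steps += 1
train_loss = train_loss / max(n_steps, 1)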
- allgather_batch(batch: Tensor) → Tensor
Allgather the batch-shards across the reader group.
- Parameters:
batch (torch.Tensor) – Batch-shard of current reader rank
- Returns:
Allgathered (full) batch
- Return type:
torch.Tensor
- calculate_val_metrics(y_pred: Tensor, y: Tensor, rollout_step: int) → tuple[dict, list[Tensor]]
Calculate metrics on the validation output.
- Parameters:
y_pred (torch.Tensor) – Predicted ensemble
y (torch.Tensor) – Ground truth (target).
rollout_step (int) – Rollout step
- Returns:
validation metrics and predictions
- Return type:
val_metrics, preds
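A short usage sketch, assuming a validation context where a prediction and its target for a given rollout step are already available; the variable names and the logging call are placeholders, not the documented pattern.

# Hypothetical call: y_pred and y are tensors for one rollout step.
val_metrics, preds = self.calculate_val_metrics(y_pred, y, rollout_step=0)
for name, value in val_metrics.items():
    self.log(f"val_{name}", value)  # assumed logging pattern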
- training_step(batch: Tensor, batch_idx: int) → Tensor
Here you compute and return the training loss and some additional metrics for e.g. the progress bar or logger.
- Parameters:
batch – The output of your data iterable, normally a DataLoader.
batch_idx – The index of this batch.
dataloader_idx – The index of the dataloader that produced this batch. (only if multiple dataloaders used)
- Returns:
- Tensor – The loss tensor
- dict – A dictionary which can include any keys, but must include the key 'loss' in the case of automatic optimization.
- None – In automatic optimization, this will skip to the next batch (but is not supported for multi-GPU, TPU, or DeepSpeed). For manual optimization, this has no special meaning, as returning the loss is not required.
In this step you’d normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something model specific.
Example:
def training_step(self, batch, batch_idx):
    x, y, z = batch
    out = self.encoder(x)
    loss = self.loss(out, x)
    return loss
To use multiple optimizers, you can switch to ‘manual optimization’ and control their stepping:
def __init__(self):
    super().__init__()
    self.automatic_optimization = False


# Multiple optimizers (e.g.: GANs)
def training_step(self, batch, batch_idx):
    opt1, opt2 = self.optimizers()

    # do training_step with encoder
    ...
    opt1.step()
    # do training_step with decoder
    ...
    opt2.step()
Note
When accumulate_grad_batches > 1, the loss returned here will be automatically normalized by accumulate_grad_batches internally.
- lr_scheduler_step(scheduler: CosineLRScheduler, metric: None = None) → None
Step the learning rate scheduler, called by PyTorch Lightning.
- Parameters:
scheduler (CosineLRScheduler) – Learning rate scheduler object.
metric (Optional[Any]) – Metric object for e.g. ReduceLRonPlateau. Default is None.
- on_train_epoch_end() → None
Called in the training loop at the very end of the epoch.
To access all batch outputs at the end of the epoch, you can cache step outputs as an attribute of the LightningModule and access them in this hook:

class MyLightningModule(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.training_step_outputs = []

    def training_step(self):
        loss = ...
        self.training_step_outputs.append(loss)
        return loss

    def on_train_epoch_end(self):
        # do something with all training_step outputs, for example:
        epoch_mean = torch.stack(self.training_step_outputs).mean()
        self.log("training_epoch_mean", epoch_mean)
        # free up the memory
        self.training_step_outputs.clear()
- class anemoi.training.train.forecaster.ensforecaster.GraphEnsForecaster(*, config: DictConfig, graph_data: HeteroData, truncation_data: dict, statistics: dict, data_indices: dict, metadata: dict, supporting_arrays: dict)
Bases:
GraphForecaster
Graph neural network forecaster for ensembles for PyTorch Lightning.
- forward(x: Tensor, fcstep: int) → Tensor
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
- gather_and_compute_loss(y_pred: torch.Tensor, y: torch.Tensor, loss: torch.nn.Module, nens_per_device: int, ens_comm_group_size: int, ens_comm_group: ProcessGroup, model_comm_group: ProcessGroup, return_pred_ens: bool = False) → tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
Gather the ensemble members from all devices in my group.
Eliminate duplicates (if any) and compute the loss.
- Parameters:
y_pred (torch.Tensor) – Predicted state tensor, calculated on self.device
y (torch.Tensor) – Ground truth
loss (torch.nn.Module) – Loss function
nens_per_device (int) – Number of ensemble members per device
ens_comm_group_size (int) – Size of the ensemble communication group
ens_comm_group (ProcessGroup) – Ensemble communication process group
model_comm_group (ProcessGroup) – Model communication process group
return_pred_ens (bool) – Validation flag: if True, return the predicted ensemble (post-gather)
- Returns:
loss_inc – Loss
y_pred_ens – Predictions if validation mode
- rollout_step(batch: torch.Tensor, rollout: int | None = None, training_mode: bool = True, validation_mode: bool = False) → Generator[tuple[torch.Tensor | None, dict, list], None, None]
Rollout step for the forecaster.
Will run pre_processors on batch, but not post_processors on predictions.
- Parameters:
batch (torch.Tensor) – Batch to use for rollout
rollout (Optional[int], optional) – Number of times to rollout for, by default None If None, will use self.rollout
training_mode (bool, optional) – Whether in training mode and to calculate the loss, by default True If False, loss will be None
validation_mode (bool, optional) – Whether in validation mode, and to calculate validation metrics, by default False If False, metrics will be empty
- Yields:
Generator[tuple[Union[torch.Tensor, None], dict, list], None, None] – Loss value, metrics, and predictions (per step)
- Returns:
None
- Return type:
None
- training_step(batch: tuple[Tensor, ...], batch_idx: int) → Tensor | dict
Run one training step.
- Parameters:
batch (tuple) – Batch data; tuple of length 1 or 2.
batch[0]: analysis, shape (bs, multi_step + rollout, nvar, latlon)
batch[1] (optional, with ensemble): EDA perturbations, shape (multi_step, nens_per_device, nvar, latlon)
batch_idx (int) – Training batch index
- Returns:
Training loss
- Return type:
train_loss
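To illustrate the expected batch layout, here is a hypothetical sketch with made-up dimension sizes; only the shape conventions come from the parameter description above.

import torch

# Hypothetical sizes, for illustration only.
bs, multi_step, rollout = 2, 2, 1
nens_per_device, nvar, latlon = 4, 98, 40320

analysis = torch.randn(bs, multi_step + rollout, nvar, latlon)
eda_perturbations = torch.randn(multi_step, nens_per_device, nvar, latlon)

# Tuple of length 1 (analysis only) or 2 (with EDA perturbations).
batch = (analysis, eda_perturbations)
# With a constructed GraphEnsForecaster instance (assumed to exist as `model`),
# this batch would be consumed as model.training_step(batch, batch_idx=0).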
- class anemoi.training.train.forecaster.interpolator.GraphInterpolator(*, config: DictConfig, graph_data: HeteroData, statistics: dict, data_indices: IndexCollection, metadata: dict, supporting_arrays: dict)
Bases:
GraphForecaster
Graph neural network interpolator for PyTorch Lightning.
- forward(x: Tensor, target_forcing: Tensor) → Tensor
Same as torch.nn.Module.forward().
- Parameters:
*args – Whatever you decide to pass into the forward method.
**kwargs – Keyword arguments are also possible.
- Returns:
Your model’s output
Trainer
The AnemoiTrainer object in train.py is responsible for calling the training function.
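A minimal sketch of driving training through this object, assuming a fully composed OmegaConf/Hydra configuration and a train() entry point; both assumptions follow the description above rather than a documented API.

from omegaconf import OmegaConf

from anemoi.training.train.train import AnemoiTrainer

# Hypothetical: load an already composed training configuration. In practice
# the config is usually assembled by Hydra from the anemoi-training config groups.
config = OmegaConf.load("config.yaml")

trainer = AnemoiTrainer(config)
trainer.train()  # assumed entry point that starts the PyTorch Lightning fit loop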
- class anemoi.training.train.train.AnemoiTrainer(config: DictConfig)
Bases:
object
Utility class for training the model.
- property initial_seed: int
Initial seed for the RNG.
This sets the same initial seed for all ranks. Ranks are re-seeded in the strategy to account for model communication groups.
- property graph_data: HeteroData
Graph data.
Creates the graph in all workers.
- property model: LightningModule
Provide the model instance.
- property wandb_logger: WandbLogger
WandB logger.
- property mlflow_logger: MLFlowLogger
MLflow logger.
- property tensorboard_logger: TensorBoardLogger
TensorBoard logger.