Diagnostics

The diagnostics module in anemoi-training is used to monitor progress during training. It is split into two parts:

  1. tracking of training progress with a standard machine-learning tracking tool, which monitors the training and validation losses and uploads the plots created by the callbacks;

  2. a series of callbacks, evaluated on the validation dataset, including plots of example forecasts and power spectra.

Trackers

By default, anemoi-training uses the MLflow tracker, but it also includes functionality to use Weights & Biases and TensorBoard.
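Which tracker is used is controlled from the config. As a minimal sketch, assuming the tracker toggles sit under config.diagnostics.log with enabled flags (check the default config shipped with your version for the exact keys):

log:
   mlflow:
      enabled: True # default tracker
   wandb:
      enabled: False
   tensorboard:
      enabled: False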

Callbacks

The callbacks can also be used to evaluate forecasts over longer rollouts beyond the forecast time that the model is trained on. The number of rollout steps for verification (or forecast iteration steps) is set using config.dataloader.validation_rollout = *num_of_rollout_steps*.
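For example, to verify over twelve rollout steps (the value is illustrative):

dataloader:
   validation_rollout: 12 # number of rollout steps used for validation forecasts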

Callbacks are configured in the config file under the config.diagnostics key.

Regular callbacks can be provided as a list of dictionaries under the config.diagnostics.callbacks key. Each dictionary must have a _target_ key, which Hydra uses to instantiate the callback; any other keyword argument is passed to the callback's constructor.

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
     rollout: ${dataloader.validation_rollout}
     every_n_batches: 20

Plotting callbacks are configured in a similar way, but they are specified under the config.diagnostics.plot.callbacks key. This is done to ensure separation and ease of configuration between experiments.

config.diagnostics.plot is a broader config section specifying the parameters to plot, as well as the plotting frequency and whether plotting runs asynchronously.

Setting config.diagnostics.plot.asynchronous means that model training doesn't stop whilst the callbacks are being evaluated. This is useful for large models, where plotting can take a long time. The plotting module uses asynchronous callbacks via asyncio and concurrent.futures.ThreadPoolExecutor to handle plotting tasks without blocking the main application. A dedicated event loop runs in a separate background thread, allowing plotting tasks to be offloaded to worker threads. This setup keeps the main thread responsive, handling plot-related tasks asynchronously and efficiently in the background.

There is an additional flag in the plotting callbacks to control the rendering method for geospatial plots, offering a trade-off between performance and detail. When datashader is set to True, Datashader is used for rendering, which accelerates plotting through efficient hexbinning and is particularly useful for large datasets. This approach can produce smoother-looking plots due to the aggregation of data points. If datashader is set to False, matplotlib.scatter is used, which provides sharper and more detailed visuals but may be slower for large datasets.

Note: this asynchronous behaviour is only available for the plotting callbacks.

plot:
   asynchronous: True # Whether to plot asynchronously
   datashader: True # Whether to use datashader for plotting (faster)
   frequency: # Frequency of the plotting
      batch: 750
      epoch: 5

   # Parameters to plot
   parameters:
      - z_500
      - t_850
      - u_850

   # Sample index
   sample_idx: 0

   # Precipitation and related fields
   precip_and_related_fields: [tp, cp]

   callbacks:
      - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
        # group parameters by categories when visualizing contributions to the loss
        # one-parameter groups are possible to highlight individual parameters
        parameter_groups:
           moisture: [tp, cp, tcw]
           sfc_wind: [10u, 10v]
      - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
        sample_idx: ${diagnostics.plot.sample_idx}
        per_sample: 6
        parameters: ${diagnostics.plot.parameters}

Below is the documentation for the default callbacks provided, but it is also possible for users to add their own callbacks using the same structure.
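For example, a user-defined callback is referenced by its import path, exactly like the built-in ones; my_project.callbacks.MyCustomCallback and its threshold argument below are hypothetical placeholders, not part of anemoi-training:

callbacks:
   - _target_: my_project.callbacks.MyCustomCallback # hypothetical user-defined callback
     threshold: 0.5 # hypothetical kwarg forwarded to the constructor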

class anemoi.training.diagnostics.callbacks.checkpoint.AnemoiCheckpoint(config: OmegaConf, **kwargs: dict)

Bases: ModelCheckpoint

A checkpoint callback that saves the model after every validation epoch.

on_train_end(trainer: pl.Trainer, pl_module: pl.LightningModule) → None

Save the last checkpoint at the end of training.

If the candidates aren't better than the last checkpoint, then no checkpoints are saved. Note: if this method is triggered when using max_epochs, it won't save any checkpoints, since the monitor candidates won't show any changes with regard to the on_train_epoch_end hook.

class anemoi.training.diagnostics.callbacks.evaluation.RolloutEval(config: OmegaConf, rollout: int, every_n_batches: int)

Bases: Callback

Evaluates the model performance over a (longer) rollout window.

on_validation_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, outputs: list, batch: torch.Tensor, batch_idx: int) → None

Called when the validation batch ends.

class anemoi.training.diagnostics.callbacks.optimiser.LearningRateMonitor(config: DictConfig, logging_interval: str = 'step', log_momentum: bool = False)

Bases: LearningRateMonitor

Provide LearningRateMonitor from pytorch_lightning as a callback.

class anemoi.training.diagnostics.callbacks.optimiser.StochasticWeightAveraging(config: DictConfig, swa_lrs: int | None = None, swa_epoch_start: int | None = None, annealing_epochs: int | None = None, annealing_strategy: str | None = None, device: str | None = None, **kwargs)

Bases: StochasticWeightAveraging

Provide StochasticWeightAveraging from pytorch_lightning as a callback.
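Based on the signature above, the callback could be registered through the generic config.diagnostics.callbacks list described earlier; the numeric values below are purely illustrative, and your config version may instead wire SWA through a dedicated training option:

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.optimiser.StochasticWeightAveraging
     swa_lrs: 1.0e-4 # illustrative learning rate used during averaging
     swa_epoch_start: 100 # illustrative epoch at which averaging starts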

class anemoi.training.diagnostics.callbacks.plot.BasePlotCallback(config: BaseSchema)

Bases: Callback, ABC

Factory for creating a callback that plots data to Experiment Logging.

start_event_loop() → None

Start the event loop in a separate thread.

teardown(trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) → None

Teardown the callback.

async submit_plot(trainer: pl.Trainer, *args: Any, **kwargs: Any) → None

Async function or coroutine to schedule the plot function.

class anemoi.training.diagnostics.callbacks.plot.BasePerBatchPlotCallback(config: OmegaConf, every_n_batches: int | None = None)

Bases: BasePlotCallback

Base Callback for plotting at the end of each batch.

on_validation_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, output: list[torch.Tensor], batch: torch.Tensor, batch_idx: int, **kwargs) → None

Called when the validation batch ends.

class anemoi.training.diagnostics.callbacks.plot.BasePerEpochPlotCallback(config: OmegaConf, every_n_epochs: int | None = None)

Bases: BasePlotCallback

Base Callback for plotting at the end of each epoch.

on_validation_epoch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, **kwargs) → None

Called when the validation epoch ends.

class anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots(config: OmegaConf, rollout: list[int], sample_idx: int, parameters: list[str], video_rollout: int = 0, accumulation_levels_plot: list[float] | None = None, colormaps: dict[str, Colormap] | None = None, per_sample: int = 6, every_n_epochs: int = 1, animation_interval: int = 400)

Bases: BasePlotCallback

Evaluates the model performance over a (longer) rollout window.

This callback allows evaluating the performance of the model over an extended number of rollout steps to observe long-term behaviour. Add the callback to the configuration file as follows:

Example:

- _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
  rollout:
     - ${dataloader.validation_rollout}
  video_rollout: ${dataloader.validation_rollout}
  every_n_epochs: 1
  sample_idx: ${diagnostics.plot.sample_idx}
  parameters: ${diagnostics.plot.parameters}

The rollout steps selected for plots and video must be less than or equal to dataloader.validation_rollout. Increasing dataloader.validation_rollout has no effect on the rollout steps used during training; it only ensures that enough time steps are available in the validation batches for the plots and video.

Creating one animation of one variable over 56 rollout steps takes about one minute. Recommended use for video generation: fork the run using fork_run_id for one additional epoch with videos enabled.

on_validation_batch_end(trainer: pl.Trainer, pl_module: pl.LightningModule, output: list[torch.Tensor], batch: torch.Tensor, batch_idx: int) → None

Called when the validation batch ends.

class anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot(config: OmegaConf, every_n_epochs: int | None = None)

Bases: BasePerEpochPlotCallback

Visualize the trainable features defined on the graph's nodes and edges.
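Given the signature above, a minimal entry under config.diagnostics.plot.callbacks might look like this (plotting every 5 epochs is an arbitrary illustrative choice):

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
     every_n_epochs: 5 # illustrative plotting interval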

class anemoi.training.diagnostics.callbacks.plot.PlotLoss(config: OmegaConf, parameter_groups: dict[str, list[str]], every_n_batches: int | None = None)

Bases: BasePerBatchPlotCallback

Plots the unsqueezed loss over rollouts.

property sort_and_color_by_parameter_group: tuple[ndarray, ndarray, dict, list]

Sort parameters by group and prepare colors.

class anemoi.training.diagnostics.callbacks.plot.PlotSample(config: OmegaConf, sample_idx: int, parameters: list[str], accumulation_levels_plot: list[float], precip_and_related_fields: list[str] | None = None, colormaps: dict[str, Colormap] | None = None, per_sample: int = 6, every_n_batches: int | None = None, **kwargs: Any)

Bases: BasePerBatchPlotCallback

Plots a post-processed sample: input, target and prediction.

class anemoi.training.diagnostics.callbacks.plot.BasePlotAdditionalMetrics(config: OmegaConf, every_n_batches: int | None = None)

Bases: BasePerBatchPlotCallback

Base processing class for additional metrics.

class anemoi.training.diagnostics.callbacks.plot.PlotSpectrum(config: OmegaConf, sample_idx: int, parameters: list[str], min_delta: float | None = None, every_n_batches: int | None = None)

Bases: BasePlotAdditionalMetrics

Plots the power spectrum, comparing target and prediction.

The actual increment (output - input) is plotted for prognostic variables, while the output is plotted for diagnostic ones.
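Following the same pattern as PlotSample in the example above, this callback can be added under config.diagnostics.plot.callbacks; a sketch based on the signature, with the optional min_delta omitted:

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
     sample_idx: ${diagnostics.plot.sample_idx}
     parameters: ${diagnostics.plot.parameters}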

class anemoi.training.diagnostics.callbacks.plot.PlotHistogram(config: OmegaConf, sample_idx: int, parameters: list[str], precip_and_related_fields: list[str] | None = None, every_n_batches: int | None = None)

Bases: BasePlotAdditionalMetrics

Plots histograms comparing target and prediction.

The actual increment (output - input) is plotted for prognostic variables, while the output is plotted for diagnostic ones.
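A corresponding entry for the histogram plots, again a sketch based on the signature above:

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
     sample_idx: ${diagnostics.plot.sample_idx}
     parameters: ${diagnostics.plot.parameters}
     precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}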

class anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback(config: OmegaConf)

Bases: Callback

A callback that retrieves the parent UUID for a model, if it is a child model.

on_load_checkpoint(trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: torch.nn.Module) → None

Called when loading a model checkpoint; use this hook to reload state.

Parameters:
  • trainer – the current Trainer instance.

  • pl_module – the current LightningModule instance.

  • checkpoint – the full checkpoint dictionary that got loaded by the Trainer.