Diagnostics

The diagnostics module in anemoi-training is used to monitor progress during training. It is split into two parts:

  1. logging training to a standard machine learning experiment-tracking tool, which monitors the training and validation losses and uploads the plots created by the callbacks;

  2. a series of callbacks, evaluated on the validation dataset, including plots of example forecasts and power spectra plots.

Trackers

By default, anemoi-training uses the MLflow tracker, but it also supports Weights & Biases and TensorBoard.

Callbacks

The callbacks can also be used to evaluate forecasts over rollouts longer than the forecast length the model is trained on. The number of rollout steps used for verification (i.e. forecast iteration steps) is set with config.dataloader.validation_rollout = num_of_rollout_steps.

Callbacks are configured in the config file under the config.diagnostics key.

Regular callbacks are provided as a list of dictionaries under the config.diagnostics.callbacks key. Each dictionary must have a _target_ key, which Hydra uses to instantiate the callback; every other key is passed as a keyword argument to the callback's constructor.

callbacks:
   - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
     rollout: ${dataloader.validation_rollout}
     frequency: 20
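The _target_ mechanism can be pictured with a small standard-library sketch. This is not Hydra's code (in practice hydra.utils.instantiate does this, plus interpolation and recursive instantiation); the helper instantiate_from_config is hypothetical, and datetime.timedelta stands in for a callback class so the example runs without anemoi-training installed.

```python
import importlib
from typing import Any


def instantiate_from_config(cfg: dict[str, Any]) -> Any:
    """Hypothetical helper mimicking how a `_target_` entry is turned into an object.

    The dotted path in `_target_` is split into module and attribute, the class
    is imported, and every remaining key is passed as a keyword argument.
    """
    kwargs = dict(cfg)  # copy so the caller's config is not mutated
    target = kwargs.pop("_target_")
    module_path, _, attr_name = target.rpartition(".")
    cls = getattr(importlib.import_module(module_path), attr_name)
    return cls(**kwargs)


# A list of callback configs, shaped like config.diagnostics.callbacks.
# datetime.timedelta is only a stand-in for a real callback class.
callback_cfgs = [
    {"_target_": "datetime.timedelta", "days": 1, "hours": 2},
]
callbacks = [instantiate_from_config(cfg) for cfg in callback_cfgs]
```

Because only _target_ is consumed, adding a new constructor argument to a callback needs no code change: it is simply another key in the dictionary.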

Plotting callbacks are configured in a similar way, but they are specified under the config.diagnostics.plot.callbacks key. This keeps them separate from the other callbacks and makes them easy to configure across experiments. config.diagnostics.plot is a broader config group that specifies the parameters to plot, the plotting frequency, and whether plotting runs asynchronously.

Setting config.diagnostics.plot.asynchronous to True means that model training does not stop while the callbacks are being evaluated. This is useful for large models, where plotting can take a long time. The plotting module implements asynchronous callbacks with asyncio and concurrent.futures.ThreadPoolExecutor so that plotting tasks do not block the main application: a dedicated event loop runs in a separate background thread and offloads plotting tasks to worker threads, which keeps the main (training) thread responsive.
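A minimal sketch of the background-loop pattern described above, assuming nothing about the actual anemoi-training classes: the name BackgroundPlotter and its submit/shutdown methods are hypothetical, and make_plot stands in for an expensive plotting call.

```python
import asyncio
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable


class BackgroundPlotter:
    """Hypothetical non-blocking plot runner (illustration only).

    An asyncio event loop runs forever in a daemon thread; each submitted
    plotting function is scheduled on that loop, which hands it to a
    ThreadPoolExecutor, so the calling (training) thread returns immediately.
    """

    def __init__(self, max_workers: int = 1) -> None:
        self.executor = ThreadPoolExecutor(max_workers=max_workers)
        self.loop = asyncio.new_event_loop()
        self.thread = threading.Thread(target=self._run_loop, daemon=True)
        self.thread.start()

    def _run_loop(self) -> None:
        asyncio.set_event_loop(self.loop)
        self.loop.run_forever()

    async def _run_in_executor(self, fn: Callable[..., Any], *args: Any) -> Any:
        # Offload the blocking plotting function to a worker thread.
        return await self.loop.run_in_executor(self.executor, fn, *args)

    def submit(self, fn: Callable[..., Any], *args: Any) -> Future:
        """Schedule fn(*args) on the background loop; returns immediately."""
        return asyncio.run_coroutine_threadsafe(self._run_in_executor(fn, *args), self.loop)

    def shutdown(self) -> None:
        # Stop the loop, wait for its thread, then release the worker threads.
        self.loop.call_soon_threadsafe(self.loop.stop)
        self.thread.join()
        self.executor.shutdown(wait=True)
        self.loop.close()


def make_plot(step: int) -> str:
    # Stand-in for an expensive matplotlib call.
    return f"plot for step {step}"


plotter = BackgroundPlotter()
futures = [plotter.submit(make_plot, step) for step in range(3)]
# ... training would continue here without waiting for the plots ...
results = [f.result() for f in futures]
plotter.shutdown()
```

Each submit call returns a concurrent.futures.Future straight away, so the caller only blocks if and when it asks for a result.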

Plot adapter compatibility

Task-specific plot adapters normalize output handling so plotting callbacks can use the same interface across task types:

  • forecaster tasks use ForecasterPlotAdapter;

  • autoencoder tasks use AutoencoderPlotAdapter;

  • temporal downscaler tasks use TemporalDownscalerPlotAdapter.

These adapters rely on the shared task _step return format (loss, metrics, predictions), where predictions is always a list of dataset-keyed dictionaries.
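To make the shared format concrete, here is a hypothetical sketch: a _step-style result whose predictions element is a list of dictionaries keyed by dataset name, plus a helper that collects one dataset's entries. Plain lists stand in for tensors, treating each list element as one output step is an assumption, and the helper and dataset names are illustrative rather than part of anemoi-training.

```python
from typing import Any


def predictions_for_dataset(predictions: list[dict[str, Any]], dataset_name: str) -> list[Any]:
    """Hypothetical helper: gather one dataset's entries from dataset-keyed dicts."""
    return [outputs[dataset_name] for outputs in predictions if dataset_name in outputs]


# Illustrative _step-style return value: (loss, metrics, predictions).
# Assumption: each list element holds the outputs of one step.
loss = 0.42
metrics = {"val_mse": 0.40}
predictions = [
    {"data": [1.0, 2.0], "other": [3.0]},
    {"data": [1.1, 2.1], "other": [3.1]},
]
step_result = (loss, metrics, predictions)

_, _, preds = step_result
data_preds = predictions_for_dataset(preds, "data")
```

Because the structure is identical across task types, a plotting callback can unpack it the same way regardless of which adapter produced it.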

Focus Area

Plotting callbacks (such as PlotSample and PlotLoss) support a focus_area parameter. This allows you to restrict the geographic scope of plots to specific regions or masks. A focus area can be defined in two ways:

  • Mask Name: A mask_attr_name string referencing a boolean mask defined within the graph data.

  • Lat/Lon Bounds: A latlon_bbox list specifying a bounding box: [lat_min, lon_min, lat_max, lon_max].

When a focus area is applied, the plot filenames and experiment log tags will automatically include a suffix (e.g., _mask_attr_name or _latlon_bbox) to distinguish them from global plots.

# Example: focusing on multiple specific geographic regions
focus_areas:
   europe:
      latlon_bbox: [30.0, -20.0, 60.0, 40.0]
   china:
      latlon_bbox: [18.0, 73.0, 54.0, 135.0]
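The latlon_bbox test can be sketched with NumPy as follows. The helper latlon_bbox_mask is hypothetical, the city coordinates are illustrative, and wrapping longitudes to [-180, 180) is an assumption about how grids stored as 0..360 might be handled; this simple version does not support boxes that cross the antimeridian.

```python
import numpy as np


def latlon_bbox_mask(lat, lon, bbox):
    """Hypothetical helper: boolean mask of points inside [lat_min, lon_min, lat_max, lon_max].

    Longitudes are wrapped to [-180, 180) first (an assumption), so a grid
    stored as 0..360 still matches boxes written with negative longitudes.
    """
    lat_min, lon_min, lat_max, lon_max = bbox
    lat = np.asarray(lat, dtype=float)
    lon = (np.asarray(lon, dtype=float) + 180.0) % 360.0 - 180.0
    return (lat >= lat_min) & (lat <= lat_max) & (lon >= lon_min) & (lon <= lon_max)


# Same boxes as the focus_areas example above.
focus_areas = {
    "europe": [30.0, -20.0, 60.0, 40.0],
    "china": [18.0, 73.0, 54.0, 135.0],
}
# Illustrative points: Paris, Beijing, Sydney, Lisbon (longitudes stored as 0..360).
lat = np.array([48.9, 39.9, -33.9, 38.7])
lon = np.array([2.35, 116.4, 151.2, 350.9])
masks = {name: latlon_bbox_mask(lat, lon, bbox) for name, bbox in focus_areas.items()}
```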

Rendering Methods

An additional flag in the plotting callbacks, datashader, controls the rendering method for geospatial plots, offering a trade-off between performance and detail.

  • When datashader is set to True, Datashader is used for rendering, which accelerates plotting through efficient hexbinning and is particularly useful for large datasets. Because data points are aggregated, this approach can produce smoother-looking plots.

  • When datashader is set to False, matplotlib.scatter is used, which provides sharper and more detailed visuals but may be slower for large datasets.

Projection

Plotting callbacks also support config.diagnostics.plot.projection_kind to control the map projection used for geospatial figures.

  • equirectangular (default): regular axes, no Cartopy dependency.

  • lambert_conformal: regional Lambert Conformal projection fitted to the plotted latitude/longitude domain (requires Cartopy).

When datashader: True is enabled, plotting is forced to equirectangular because Datashader rendering does not support Cartopy transforms.
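The fallback rule above can be expressed as a small function. resolve_projection and SUPPORTED_PROJECTIONS are hypothetical names used for illustration, not part of the anemoi-training API.

```python
SUPPORTED_PROJECTIONS = {"equirectangular", "lambert_conformal"}


def resolve_projection(projection_kind: str = "equirectangular", datashader: bool = False) -> str:
    """Hypothetical helper: return the projection a geospatial plot would actually use."""
    if projection_kind not in SUPPORTED_PROJECTIONS:
        raise ValueError(f"Unknown projection_kind: {projection_kind!r}")
    if datashader:
        # Datashader output cannot be drawn through Cartopy transforms,
        # so the plot falls back to plain equirectangular axes.
        return "equirectangular"
    return projection_kind
```

Checking the flag in one place keeps the precedence explicit: datashader wins over projection_kind, whatever the latter is set to.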

Note: the asynchronous behaviour controlled by config.diagnostics.plot.asynchronous is only available for the plotting callbacks.

Progress Bar

The progress bar callback can be configured to control how training progress is displayed. This is particularly useful on HPC systems with SLURM where output is written to files, as the default RichProgressBar in PyTorch Lightning 2.6+ may not work correctly. The progress bar is controlled by two configuration options:

  • enable_progress_bar: A boolean flag to enable or disable the progress bar entirely

  • progress_bar: Configuration for which progress bar callback to use

enable_progress_bar: True
progress_bar:
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: 1

Lightning 2.6+ supports the RichProgressBar (https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichProgressBar.html#lightning.pytorch.callbacks.RichProgressBar), which is recommended for interactive terminals, and the TQDMProgressBar (https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.TQDMProgressBar.html#lightning.pytorch.callbacks.TQDMProgressBar), which should be used with SLURM.
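One way to encode this recommendation is a helper that returns a progress_bar config in the same _target_ form as the YAML example. The helper name and the heuristic (checking the SLURM_JOB_ID environment variable and whether the output stream is a terminal) are assumptions, not anemoi-training behaviour.

```python
import os
import sys


def default_progress_bar_config(stream=None) -> dict:
    """Hypothetical helper: choose a progress-bar config in `_target_` form.

    Assumed heuristic: use TQDM under SLURM or when output is not an
    interactive terminal (e.g. redirected to a log file), Rich otherwise.
    """
    stream = sys.stdout if stream is None else stream
    under_slurm = "SLURM_JOB_ID" in os.environ
    if under_slurm or not stream.isatty():
        return {"_target_": "pytorch_lightning.callbacks.TQDMProgressBar", "refresh_rate": 1}
    return {"_target_": "pytorch_lightning.callbacks.RichProgressBar"}
```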

plot:
   asynchronous: True # Whether to plot asynchronously
   datashader: True # Whether to use datashader for plotting (faster)
   projection_kind: equirectangular # or lambert_conformal (requires Cartopy)
   frequency: # Frequency of the plotting
      batch: 750
      epoch: 5

   # Parameters to plot
   parameters:
      - z_500
      - t_850
      - u_850

   # Sample index
   sample_idx: 0

   # Precipitation and related fields
   precip_and_related_fields: [tp, cp]

   datasets_to_plot: ["data"]

   focus_areas:
      europe:
         latlon_bbox: [30.0, -20.0, 60.0, 40.0]
      china:
         latlon_bbox: [18.0, 73.0, 54.0, 135.0]

   callbacks:
      - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
        dataset_names: ["your_dataset_name"]
        # group parameters by categories when visualizing contributions to the loss
        # one-parameter groups are possible to highlight individual parameters
        parameter_groups:
           moisture: [tp, cp, tcw]
           sfc_wind: [10u, 10v]

      - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
        dataset_names: ["your_dataset_name"]
        sample_idx: ${diagnostics.plot.sample_idx}
        per_sample: 6
        parameters: ${diagnostics.plot.parameters}
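The parameter_groups option of PlotLoss can be pictured as folding per-parameter loss contributions into named groups. This is a hypothetical sketch: the helper name and the loss values are illustrative, summing is an assumed aggregation, and keeping ungrouped parameters under their own name is an assumption about how they are treated.

```python
def group_loss_contributions(
    per_param_loss: dict[str, float], parameter_groups: dict[str, list[str]]
) -> dict[str, float]:
    """Hypothetical helper: sum per-parameter loss into the configured groups.

    Assumption: parameters not listed in any group are kept under their own name.
    """
    param_to_group = {p: group for group, params in parameter_groups.items() for p in params}
    grouped: dict[str, float] = {}
    for param, value in per_param_loss.items():
        key = param_to_group.get(param, param)
        grouped[key] = grouped.get(key, 0.0) + value
    return grouped


# Same groups as the PlotLoss example above; the loss values are made up.
parameter_groups = {"moisture": ["tp", "cp", "tcw"], "sfc_wind": ["10u", "10v"]}
per_param_loss = {"tp": 0.25, "cp": 0.125, "tcw": 0.125, "10u": 0.5, "10v": 0.25, "z_500": 1.0}
grouped = group_loss_contributions(per_param_loss, parameter_groups)
```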

Below is the documentation for the default callbacks provided; users can also add their own callbacks using the same structure:

Plot adapter: single entry point for diagnostics callbacks.

Groups the plot-related hooks so task classes expose one attribute (plot_adapter) instead of five small methods.

The EnsemblePlotAdapterWrapper can wrap any task-specific adapter, adding ensemble member handling without modifying the inner adapter's logic.

class anemoi.training.diagnostics.callbacks.plot_adapter.BasePlotAdapter(task: BaseTask)

Bases: ABC

Abstract plotting contract. Subclasses define output_times, get_init_step, iter_plot_samples.

select_members(tensor: Any, members: int | list[int] | None = None) → Any

Select ensemble members from tensor. No-op for non-ensemble adapters.

prepare_loss_batch(batch: dict) → dict

Prepare batch for loss plotting. No-op for non-ensemble adapters.

abstractmethod iter_plot_samples(data: Any, output_tensor: Any) → Iterator[tuple[Any, Any, Any, str]]

Yield (x, y_true, y_pred, tag_suffix) or (sample, recon, tag) per plot sample.
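To illustrate the contract, here is a toy adapter hierarchy mirroring the documented hooks. It is a standalone sketch, not the real BasePlotAdapter: the toy classes take no task argument, output_times and get_init_step are omitted, and the "recon" tag is hypothetical.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterator
from typing import Any


class ToyPlotAdapter(ABC):
    """Hypothetical stand-in for the documented BasePlotAdapter hooks."""

    def select_members(self, tensor: Any, members: int | list[int] | None = None) -> Any:
        # Non-ensemble adapters have no member dimension: return the input unchanged.
        return tensor

    def prepare_loss_batch(self, batch: dict) -> dict:
        # Likewise a no-op outside ensemble training.
        return batch

    @abstractmethod
    def iter_plot_samples(self, data: Any, output_tensor: Any) -> Iterator[tuple[Any, ...]]:
        """Yield one tuple per plot sample."""


class ToyAutoencoderAdapter(ToyPlotAdapter):
    """Single (sample, recon, tag) yield, as documented for the autoencoder adapter."""

    def iter_plot_samples(self, data: Any, output_tensor: Any) -> Iterator[tuple[Any, ...]]:
        yield data, output_tensor, "recon"  # hypothetical tag
```

Callbacks written against the base class only ever call these hooks, so a new task type needs a new adapter rather than changes to every plotting callback.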

class anemoi.training.diagnostics.callbacks.plot_adapter.ForecasterPlotAdapter(task: BaseTask)

Bases: BasePlotAdapter

Plot Adapter to adapt plots to the rollout set-up of the Forecaster Task.

Handles multiple loss plots, n_step_output targets per step, and multi-step iteration.

iter_plot_samples(data: Any, output_tensor: Any) → Iterator[tuple[Any, Any, Any, str]]

Yield (x, y_true, y_pred, tag_suffix) or (sample, recon, tag) per plot sample.

class anemoi.training.diagnostics.callbacks.plot_adapter.TemporalDownscalerPlotAdapter(task: BaseTask)

Bases: BasePlotAdapter

Plot Adapter for TemporalDownscaler Task.

Handles squeezing (1, n_step_output, …) -> (n_step_output, …).

iter_plot_samples(data: Any, output_tensor: Any) → Iterator[tuple[Any, Any, Any, str]]

Yield (x, y_true, y_pred, tag_suffix) or (sample, recon, tag) per plot sample.
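The squeeze described for the temporal downscaler can be sketched with NumPy arrays standing in for tensors. The helper name and the example shapes are hypothetical.

```python
import numpy as np


def squeeze_leading_singleton(output: np.ndarray) -> np.ndarray:
    """Hypothetical helper: (1, n_step_output, ...) -> (n_step_output, ...)."""
    if output.ndim < 2 or output.shape[0] != 1:
        raise ValueError(f"expected a leading dimension of size 1, got shape {output.shape}")
    return output.squeeze(axis=0)


# Illustrative output: 1 x 3 output steps x 5 grid points x 2 variables.
y_pred = np.zeros((1, 3, 5, 2))
per_step = squeeze_leading_singleton(y_pred)
# After squeezing, iterating over the first axis yields one field per output step.
step_shapes = [step_field.shape for step_field in per_step]
```

Checking the leading dimension explicitly, rather than indexing with output[0], makes an unexpected shape fail loudly instead of silently dropping data.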

class anemoi.training.diagnostics.callbacks.plot_adapter.AutoencoderPlotAdapter(task: BaseTask)

Bases: BasePlotAdapter

Plot Adapter for Autoencoder Task: single (sample, recon, tag) yield.

iter_plot_samples(data: Any, output_tensor: Any) → Iterator[tuple[Any, Any, Any, str]]

Yield (x, y_true, y_pred, tag_suffix) or (sample, recon, tag) per plot sample.

class anemoi.training.diagnostics.callbacks.plot_adapter.EnsemblePlotAdapterWrapper(inner: BasePlotAdapter)

Bases: BasePlotAdapter

Wraps any task-specific adapter, adding ensemble member handling.

This adapter decorates an inner (task-specific) adapter to handle the extra ensemble dimension present in ensemble training outputs. Batch shape convention: (B, T, E, G, V) where E is ensemble members.

select_members(tensor: Any, members: int | list[int] | None = None) → Any

Slice ensemble members from dim 2 of the output tensor.

Parameters:
  • tensor (Any) – Tensor with shape (…, members, grid, vars)

  • members (int | list[int] | None) – Members to select. None returns all members, int/list selects specific members.

prepare_loss_batch(batch: dict) → dict

Squeeze ensemble dim to member 0 for loss plotting.
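The two ensemble hooks can be sketched with NumPy under the documented (B, T, E, G, V) convention. These are standalone functions, not the wrapper's methods; keeping the ensemble dimension when a single int is passed, and treating the batch as a dictionary of arrays keyed by dataset name, are assumptions.

```python
from __future__ import annotations

import numpy as np

ENSEMBLE_DIM = 2  # (B, T, E, G, V): the ensemble dimension E is dim 2


def select_members(tensor: np.ndarray, members: int | list[int] | None = None) -> np.ndarray:
    """Sketch: return all members (None) or only the requested ones along dim 2."""
    if members is None:
        return tensor
    if isinstance(members, int):
        # Assumption: wrap a single index in a list so the ensemble dim is kept.
        members = [members]
    return np.take(tensor, members, axis=ENSEMBLE_DIM)


def prepare_loss_batch(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Sketch: keep member 0 and drop the ensemble dim, (B, T, E, G, V) -> (B, T, G, V)."""
    return {name: x[:, :, 0] for name, x in batch.items()}


# Illustrative batch: 2 samples, 1 time step, 4 members, 3 grid points, 2 variables.
x = np.arange(2 * 1 * 4 * 3 * 2).reshape(2, 1, 4, 3, 2)
first_and_last = select_members(x, [0, 3])
loss_batch = prepare_loss_batch({"data": x})
```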

iter_plot_samples(data: Any, output_tensor: Any) → Iterator[tuple[Any, Any, Any, str]]

Yield (x, y_true, y_pred, tag_suffix) or (sample, recon, tag) per plot sample.