Diagnostics
The diagnostics module in anemoi-training is used to monitor progress during training. It is split into two parts:
- tracking training with a standard machine learning tracking tool, which monitors the training and validation losses and uploads the plots created by the callbacks;
- a series of callbacks, evaluated on the validation dataset, including plots of example forecasts and power spectra.
Trackers
By default, anemoi-training uses the MLflow tracker, but it can also log to Weights & Biases and TensorBoard.
Callbacks
The callbacks can also be used to evaluate forecasts over longer
rollouts beyond the forecast time that the model is trained on. The
number of rollout steps for verification (or forecast iteration steps)
is set using config.dataloader.validation_rollout =
*num_of_rollout_steps*.
Callbacks are configured in the config file under the
config.diagnostics key.
Regular callbacks are provided as a list of dictionaries under the
config.diagnostics.callbacks key. Each dictionary must have a _target_
key, which Hydra uses to instantiate the callback; any other key is
passed as a keyword argument to the callback's constructor.
callbacks:
  - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
    rollout:
      - ${dataloader.validation_rollout}
    frequency: 20
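To illustrate how a _target_ entry becomes an object, the sketch below mimics the idea using only the standard library. It is not Hydra's implementation (Hydra's own hydra.utils.instantiate also handles interpolation and nested configs); instantiate_from_dict is a hypothetical helper, and fractions.Fraction merely stands in for a callback class.

```python
import importlib
from typing import Any


def instantiate_from_dict(entry: dict[str, Any]) -> Any:
    """Import the class named by ``_target_`` and call it with the remaining keys."""
    kwargs = dict(entry)  # copy so the caller's config is left untouched
    module_path, _, class_name = kwargs.pop("_target_").rpartition(".")
    cls = getattr(importlib.import_module(module_path), class_name)
    return cls(**kwargs)


# Stand-in config: a list of dictionaries, shaped like config.diagnostics.callbacks.
# fractions.Fraction is used only because it is importable everywhere.
callback_configs = [
    {"_target_": "fractions.Fraction", "numerator": 1, "denominator": 20},
]
callbacks = [instantiate_from_dict(entry) for entry in callback_configs]
print(callbacks)
```

Every key other than _target_ ends up as a keyword argument, which is why a misspelled key typically surfaces as a constructor error at start-up rather than being silently ignored.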
Plotting callbacks are configured in a similar way, but they are
specified under the config.diagnostics.plot.callbacks key.
This keeps them separate from the regular callbacks and makes them
easier to configure between experiments.
config.diagnostics.plot is a broader config section specifying the
parameters to plot, the plotting frequency, and whether plotting runs
asynchronously.
Setting config.diagnostics.plot.asynchronous to True means that model
training does not stop while the plotting callbacks are being evaluated.
This is useful for large models, where plotting can take a long time.
The plotting module implements this with asyncio and
concurrent.futures.ThreadPoolExecutor: a dedicated event loop runs in a
separate background thread and offloads plotting tasks to worker
threads, so the main thread is never blocked by plotting.
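The pattern described above (a background event loop plus a thread pool) can be sketched with the standard library alone. This is a minimal illustration and not the anemoi-training implementation; make_plot and submit_plot are hypothetical names, and the sleep stands in for slow rendering.

```python
import asyncio
import threading
import time
from concurrent.futures import ThreadPoolExecutor


def make_plot(name: str) -> str:
    """Stand-in for a slow, blocking plotting function (hypothetical)."""
    time.sleep(0.05)
    return f"{name}.png"


# Dedicated event loop running in a daemon background thread.
loop = asyncio.new_event_loop()
loop_thread = threading.Thread(target=loop.run_forever, daemon=True)
loop_thread.start()

# Worker threads that actually execute the plotting functions.
executor = ThreadPoolExecutor(max_workers=2)


async def submit_plot(name: str) -> str:
    # Hand the blocking call to a worker thread and await its result.
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(executor, make_plot, name)


# The main ("training") thread schedules plots without waiting for them.
futures = [
    asyncio.run_coroutine_threadsafe(submit_plot(f"sample_{i}"), loop)
    for i in range(3)
]
steps_done = sum(1 for _ in range(1000))  # training work continues meanwhile

# Collect results only at the end, then shut everything down cleanly.
results = [future.result(timeout=5) for future in futures]
loop.call_soon_threadsafe(loop.stop)
loop_thread.join()
executor.shutdown()
loop.close()
print(results)
```

The main thread only blocks when it explicitly asks for a result, which is the property that keeps training running while figures are rendered.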
Plot adapter compatibility
Task-specific plot adapters normalize output handling so plotting callbacks can use the same interface across task types:
- forecaster tasks use ForecasterPlotAdapter;
- autoencoder tasks use AutoencoderPlotAdapter;
- temporal downscaler tasks use TemporalDownscalerPlotAdapter.
These adapters rely on the shared task _step return format
(loss, metrics, predictions) where predictions is always a list
of dataset-keyed dictionaries.
Focus Area
Plotting callbacks (such as PlotSample and PlotLoss) support a focus_area parameter. This allows you to restrict the geographic scope of plots to specific regions or masks. A focus area can be defined in two ways:
- Mask Name: A mask_attr_name string referencing a boolean mask defined within the graph data.
- Lat/Lon Bounds: A latlon_bbox list specifying a bounding box: [lat_min, lon_min, lat_max, lon_max].
When a focus area is applied, the plot filenames and experiment log tags will automatically include a suffix (e.g., _mask_attr_name or _latlon_bbox) to distinguish them from global plots.
# Example: focusing on multiple specific geographic regions
focus_areas:
  europe:
    latlon_bbox: [30.0, -20.0, 60.0, 40.0]
  china:
    latlon_bbox: [18.0, 73.0, 54.0, 135.0]
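As a rough illustration of what a latlon_bbox focus area selects, the sketch below builds a boolean mask over a handful of grid points. bbox_mask is a hypothetical helper, not an anemoi-training function, and the sketch assumes that longitudes use the same convention as the bounding box (here -180 to 180) and that the box does not cross the dateline.

```python
from typing import Sequence


def bbox_mask(
    lats: Sequence[float],
    lons: Sequence[float],
    latlon_bbox: Sequence[float],
) -> list[bool]:
    """Return True for every point inside [lat_min, lon_min, lat_max, lon_max]."""
    lat_min, lon_min, lat_max, lon_max = latlon_bbox
    return [
        lat_min <= lat <= lat_max and lon_min <= lon <= lon_max
        for lat, lon in zip(lats, lons)
    ]


# Four illustrative grid points: Paris, Beijing, New York, Sydney.
lats = [48.9, 39.9, 40.7, -33.9]
lons = [2.4, 116.4, -74.0, 151.2]

# Same structure as the focus_areas config above.
focus_areas = {
    "europe": {"latlon_bbox": [30.0, -20.0, 60.0, 40.0]},
    "china": {"latlon_bbox": [18.0, 73.0, 54.0, 135.0]},
}
masks = {
    name: bbox_mask(lats, lons, area["latlon_bbox"])
    for name, area in focus_areas.items()
}
print(masks)
```

Only the points for which the mask is True would be drawn in the corresponding regional plot.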
Rendering Methods
There is an additional flag in the plotting callbacks to control the rendering method for geospatial plots, offering a trade-off between performance and detail.
- When datashader is set to True, Datashader is used for rendering, which accelerates plotting through efficient hexbinning and is particularly useful for large datasets. This approach can produce smoother-looking plots because data points are aggregated.
- When datashader is set to False, matplotlib.scatter is used, which provides sharper and more detailed visuals but may be slower for large datasets.
Projection
Plotting callbacks also support config.diagnostics.plot.projection_kind
to control the map projection used for geospatial figures.
- equirectangular (default): regular axes, no Cartopy dependency.
- lambert_conformal: regional Lambert Conformal projection fitted to the plotted latitude/longitude domain (requires Cartopy).
When datashader: True is enabled, plotting is forced to
equirectangular because Datashader rendering does not support
Cartopy transforms.
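The interaction between the two settings can be summarised in a few lines. This is only a sketch of the rule stated above; resolve_projection is a hypothetical helper, and rejecting unknown names with a ValueError is an assumption rather than documented behaviour.

```python
ALLOWED_PROJECTIONS = ("equirectangular", "lambert_conformal")


def resolve_projection(projection_kind: str, datashader: bool) -> str:
    """Return the projection that would effectively be used for a plot."""
    if projection_kind not in ALLOWED_PROJECTIONS:
        # Assumption: unknown names are rejected; the real behaviour may differ.
        raise ValueError(f"unknown projection_kind: {projection_kind!r}")
    # Datashader rendering cannot apply Cartopy transforms, so it takes precedence.
    return "equirectangular" if datashader else projection_kind


print(resolve_projection("lambert_conformal", datashader=True))
print(resolve_projection("lambert_conformal", datashader=False))
```

In practice this means that a regional Lambert Conformal plot requires datashader: False.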
Note: the asynchronous behaviour controlled by config.diagnostics.plot.asynchronous is only available for the plotting callbacks.
Progress Bar
The progress bar callback can be configured to control how training progress is displayed. This is particularly useful on HPC systems with SLURM where output is written to files, as the default RichProgressBar in PyTorch Lightning 2.6+ may not work correctly. The progress bar is controlled by two configuration options:
- enable_progress_bar: A boolean flag to enable or disable the progress bar entirely
- progress_bar: Configuration for which progress bar callback to use
enable_progress_bar: True
progress_bar:
  _target_: pytorch_lightning.callbacks.TQDMProgressBar
  refresh_rate: 1
Lightning 2.6+ supports the [RichProgressBar](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.RichProgressBar.html#lightning.pytorch.callbacks.RichProgressBar), which is recommended for interactive terminals, and the [TQDMProgressBar](https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.TQDMProgressBar.html#lightning.pytorch.callbacks.TQDMProgressBar), which should be used with SLURM.
plot:
  asynchronous: True # Whether to plot asynchronously
  datashader: True # Whether to use datashader for plotting (faster)
  projection_kind: equirectangular # or lambert_conformal (requires Cartopy)
  frequency: # Frequency of the plotting
    batch: 750
    epoch: 5
  # Parameters to plot
  parameters:
    - z_500
    - t_850
    - u_850
  # Sample index
  sample_idx: 0
  # Precipitation and related fields
  precip_and_related_fields: [tp, cp]
  datasets_to_plot: ["data"]
  focus_areas:
    europe:
      latlon_bbox: [30.0, -20.0, 60.0, 40.0]
    china:
      latlon_bbox: [18.0, 73.0, 54.0, 135.0]
  callbacks:
    - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
      dataset_names: ["your_dataset_name"]
      # group parameters by categories when visualizing contributions to the loss
      # one-parameter groups are possible to highlight individual parameters
      parameter_groups:
        moisture: [tp, cp, tcw]
        sfc_wind: [10u, 10v]
    - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
      dataset_names: ["your_dataset_name"]
      sample_idx: ${diagnostics.plot.sample_idx}
      per_sample: 6
      parameters: ${diagnostics.plot.parameters}
Below is the documentation for the default callbacks provided, but it is also possible for users to add callbacks using the same structure:
Plot adapter: single entry point for diagnostics callbacks.
Groups the plot-related hooks so task classes expose one attribute (plot_adapter) instead of five small methods.
The EnsemblePlotAdapterWrapper can wrap any task-specific adapter, adding ensemble member handling without modifying the inner adapter’s logic.
- class anemoi.training.diagnostics.callbacks.plot_adapter.BasePlotAdapter(task: BaseTask)
Bases: ABC
Abstract plotting contract. Subclasses define output_times, get_init_step, iter_plot_samples.
- select_members(tensor: Any, members: int | list[int] | None = None) -> Any
Select ensemble members from tensor. No-op for non-ensemble adapters.
- class anemoi.training.diagnostics.callbacks.plot_adapter.ForecasterPlotAdapter(task: BaseTask)
Bases: BasePlotAdapter
Plot Adapter to adapt plots to the rollout set-up of the Forecaster Task.
Handles multiple loss plots, n_step_output targets per step, multi-step iter.
- class anemoi.training.diagnostics.callbacks.plot_adapter.TemporalDownscalerPlotAdapter(task: BaseTask)
Bases: BasePlotAdapter
Plot Adapter for TemporalDownscaler Task.
Handles squeezing (1, n_step_output, …) -> (n_step_output, …).
- class anemoi.training.diagnostics.callbacks.plot_adapter.AutoencoderPlotAdapter(task: BaseTask)
Bases: BasePlotAdapter
Plot Adapter for Autoencoder Task: single (sample, recon, tag) yield.
- class anemoi.training.diagnostics.callbacks.plot_adapter.EnsemblePlotAdapterWrapper(inner: BasePlotAdapter)
Bases: BasePlotAdapter
Wraps any task-specific adapter, adding ensemble member handling.
This adapter decorates an inner (task-specific) adapter to handle the extra ensemble dimension present in ensemble training outputs. Batch shape convention: (B, T, E, G, V) where E is ensemble members.
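The wrapping idea can be illustrated with a small, self-contained sketch. The classes below are hypothetical stand-ins and do not reproduce the real adapter interface: nested lists replace tensors, describe is an invented placeholder for the task-specific hooks, and returning the input unchanged when members is None is an assumption.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any


class SketchAdapter(ABC):
    """Hypothetical stand-in for a plotting contract."""

    @abstractmethod
    def describe(self) -> str: ...

    def select_members(self, tensor: Any, members: int | list[int] | None = None) -> Any:
        return tensor  # no-op: there is no ensemble dimension to select from


class SketchForecasterAdapter(SketchAdapter):
    def describe(self) -> str:
        return "forecaster"


class SketchEnsembleWrapper(SketchAdapter):
    """Decorates an inner adapter and adds handling of the E axis."""

    def __init__(self, inner: SketchAdapter) -> None:
        self.inner = inner

    def describe(self) -> str:
        return self.inner.describe()  # task-specific logic stays in the inner adapter

    def select_members(self, tensor: Any, members: int | list[int] | None = None) -> Any:
        if members is None:  # assumption: None keeps all members
            return tensor
        idx = [members] if isinstance(members, int) else list(members)
        # Axis order (B, T, E, G, V): pick members on the third axis.
        return [[[step[m] for m in idx] for step in sample] for sample in tensor]


# Tiny (B=1, T=1, E=3, G=2, V=1) batch; each value equals its member index.
batch = [[[[[float(e)] for _ in range(2)] for e in range(3)]]]

plain = SketchForecasterAdapter()
wrapped = SketchEnsembleWrapper(plain)
selected = wrapped.select_members(batch, [0, 2])
print(wrapped.describe(), len(selected[0][0]), "members kept")
```

Because the wrapper only overrides member selection and delegates everything else, the same inner adapter can serve both deterministic and ensemble training.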