Diagnostics
The diagnostics module in anemoi-training is used to monitor progress during training. It is split into two parts:
- a series of callbacks, evaluated on the validation dataset, including plots of example forecasts and power spectra;
- tracking of training with a standard machine learning tracking tool, which monitors the training and validation losses and uploads the plots created by the callbacks.
Trackers
By default, anemoi-training uses the MLflow tracker, but it also includes functionality to use Weights & Biases and TensorBoard.
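As a rough sketch, the tracker can be switched in the diagnostics config; the keys below assume a log section as in recent configs and may differ between versions:
log:
  mlflow:
    enabled: True # the default tracker
  wandb:
    enabled: False # Weights & Biases
  tensorboard:
    enabled: False # TensorBoard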
Callbacks
The callbacks can also be used to evaluate forecasts over longer rollouts, beyond the forecast time that the model is trained on. The number of rollout steps for verification (or forecast iteration steps) is set using config.dataloader.validation_rollout = *num_of_rollout_steps*.
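For example, to make twelve rollout steps available for verification (the value is illustrative):
dataloader:
  validation_rollout: 12 # number of rollout steps used in validation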
Callbacks are configured in the config file under the config.diagnostics key.
Regular callbacks can be provided as a list of dictionaries underneath the config.diagnostics.callbacks key. Each dictionary must have a _target_ key, which is used by hydra to instantiate the callback; any other kwargs are passed to the callback's constructor.
callbacks:
  - _target_: anemoi.training.diagnostics.callbacks.evaluation.RolloutEval
    rollout: ${dataloader.validation_rollout}
    frequency: 20
Plotting callbacks are configured in a similar way, but they are specified underneath the config.diagnostics.plot.callbacks key.
This is done to ensure separation and ease of configuration between experiments.
config.diagnostics.plot is a broader config section specifying the parameters to plot, as well as the plotting frequency and whether plotting runs asynchronously.
Setting config.diagnostics.plot.asynchronous means that model training doesn't stop whilst the callbacks are being evaluated. This is useful for large models, where plotting can take a long time. The plotting module uses asynchronous callbacks via asyncio and concurrent.futures.ThreadPoolExecutor to handle plotting tasks without blocking the main application. A dedicated event loop runs in a separate background thread, allowing plotting tasks to be offloaded to worker threads. This setup keeps the main thread responsive, handling plot-related tasks asynchronously and efficiently in the background.
There is an additional flag in the plotting callbacks to control the rendering method for geospatial plots, offering a trade-off between performance and detail. When datashader is set to True, Datashader is used for rendering, which accelerates plotting through efficient hexbinning and is particularly useful for large datasets. This approach can produce smoother-looking plots due to the aggregation of data points. If datashader is set to False, matplotlib.scatter is used, which provides sharper and more detailed visuals but may be slower for large datasets.
Note - this asynchronous behaviour is only available for the plotting callbacks.
plot:
  asynchronous: True # Whether to plot asynchronously
  datashader: True # Whether to use datashader for plotting (faster)
  frequency: # Frequency of the plotting
    batch: 750
    epoch: 5
  # Parameters to plot
  parameters:
    - z_500
    - t_850
    - u_850
  # Sample index
  sample_idx: 0
  # Precipitation and related fields
  precip_and_related_fields: [tp, cp]
  callbacks:
    - _target_: anemoi.training.diagnostics.callbacks.plot.PlotLoss
      # group parameters by categories when visualizing contributions to the loss
      # one-parameter groups are possible to highlight individual parameters
      parameter_groups:
        moisture: [tp, cp, tcw]
        sfc_wind: [10u, 10v]
    - _target_: anemoi.training.diagnostics.callbacks.plot.PlotSample
      sample_idx: ${diagnostics.plot.sample_idx}
      per_sample: 6
      parameters: ${diagnostics.plot.parameters}
Below is the documentation for the default callbacks provided; it is also possible for users to add their own callbacks using the same structure, as sketched below.
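For instance, a user-defined callback could be registered in the same way; the target path and keyword below are hypothetical:
callbacks:
  - _target_: my_package.callbacks.MyCustomCallback # hypothetical user-defined callback
    my_option: 42 # any extra kwarg is forwarded to the constructor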
- class anemoi.training.diagnostics.callbacks.checkpoint.AnemoiCheckpoint(config: OmegaConf, **kwargs: dict)
Bases: ModelCheckpoint
A checkpoint callback that saves the model after every validation epoch.
- on_train_end(trainer: pl.Trainer, pl_module: pl.LightningModule) -> None
Save the last checkpoint at the end of training.
If the candidates aren't better than the last checkpoint, then no checkpoints are saved. Note - if this method is triggered when using max_epochs, it won't save any checkpoints, since the monitor candidates won't show any changes with regard to the 'on_train_epoch_end' hook.
- class anemoi.training.diagnostics.callbacks.evaluation.RolloutEval(config: OmegaConf, rollout: int, every_n_batches: int)
Bases: Callback
Evaluates the model performance over a (longer) rollout window.
- class anemoi.training.diagnostics.callbacks.optimiser.LearningRateMonitor(config: DictConfig, logging_interval: str = 'step', log_momentum: bool = False)
Bases: LearningRateMonitor
Provide LearningRateMonitor from pytorch_lightning as a callback.
- class anemoi.training.diagnostics.callbacks.optimiser.StochasticWeightAveraging(config: DictConfig, swa_lrs: int | None = None, swa_epoch_start: int | None = None, annealing_epochs: int | None = None, annealing_strategy: str | None = None, device: str | None = None, **kwargs)
Bases: StochasticWeightAveraging
Provide StochasticWeightAveraging from pytorch_lightning as a callback.
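Both optimiser callbacks can be registered like any other callback; a minimal sketch, assuming they are listed under config.diagnostics.callbacks and using illustrative values for the kwargs documented above:
callbacks:
  - _target_: anemoi.training.diagnostics.callbacks.optimiser.LearningRateMonitor
    logging_interval: step # log the learning rate at every optimiser step
  - _target_: anemoi.training.diagnostics.callbacks.optimiser.StochasticWeightAveraging
    swa_lrs: 1.0e-4 # illustrative SWA learning rate
    swa_epoch_start: 100 # illustrative epoch at which averaging starts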
- class anemoi.training.diagnostics.callbacks.plot.BasePlotCallback(config: BaseSchema)
Bases: Callback, ABC
Factory for creating a callback that plots data to Experiment Logging.
- class anemoi.training.diagnostics.callbacks.plot.BasePerBatchPlotCallback(config: OmegaConf, every_n_batches: int | None = None)
Bases: BasePlotCallback
Base Callback for plotting at the end of each batch.
- class anemoi.training.diagnostics.callbacks.plot.BasePerEpochPlotCallback(config: OmegaConf, every_n_epochs: int | None = None)
Bases: BasePlotCallback
Base Callback for plotting at the end of each epoch.
- class anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots(config: OmegaConf, rollout: list[int], sample_idx: int, parameters: list[str], video_rollout: int = 0, accumulation_levels_plot: list[float] | None = None, colormaps: dict[str, Colormap] | None = None, per_sample: int = 6, every_n_epochs: int = 1, animation_interval: int = 400)
Bases: BasePlotCallback
Evaluates the model performance over a (longer) rollout window.
This callback allows evaluating the performance of the model over an extended number of rollout steps to observe long-term behaviour. Add the callback to the configuration file as follows:
Example:
- _target_: anemoi.training.diagnostics.callbacks.plot.LongRolloutPlots
  rollout:
    - ${dataloader.validation_rollout}
  video_rollout: ${dataloader.validation_rollout}
  every_n_epochs: 1
  sample_idx: ${diagnostics.plot.sample_idx}
  parameters: ${diagnostics.plot.parameters}
The selected rollout steps for plots and video need to be less than or equal to dataloader.validation_rollout. Increasing dataloader.validation_rollout has no effect on the rollout steps during training; it only ensures that enough time steps are available for the plots and video in the validation batches.
Creating one animation of one variable over 56 rollout steps takes about one minute. For video generation, it is recommended to fork the run using fork_run_id for one additional epoch with videos enabled.
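A minimal sketch of such a fork, assuming fork_run_id is set under the training config group (the run ID is a placeholder):
training:
  fork_run_id: <parent_run_id> # placeholder ID of the run to fork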
- class anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot(config: OmegaConf, every_n_epochs: int | None = None)
Bases: BasePerEpochPlotCallback
Visualize the trainable features defined on the nodes and edges.
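A minimal sketch of enabling this callback under config.diagnostics.plot.callbacks (the frequency is illustrative):
- _target_: anemoi.training.diagnostics.callbacks.plot.GraphTrainableFeaturesPlot
  every_n_epochs: 5 # illustrative plotting frequency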
- class anemoi.training.diagnostics.callbacks.plot.PlotLoss(config: OmegaConf, parameter_groups: dict[dict[str, list[str]]], every_n_batches: int | None = None)
Bases: BasePerBatchPlotCallback
Plots the unsqueezed loss over rollouts.
- class anemoi.training.diagnostics.callbacks.plot.PlotSample(config: OmegaConf, sample_idx: int, parameters: list[str], accumulation_levels_plot: list[float], precip_and_related_fields: list[str] | None = None, colormaps: dict[str, Colormap] | None = None, per_sample: int = 6, every_n_batches: int | None = None, **kwargs: Any)
Bases: BasePerBatchPlotCallback
Plots a post-processed sample: input, target and prediction.
- class anemoi.training.diagnostics.callbacks.plot.BasePlotAdditionalMetrics(config: OmegaConf, every_n_batches: int | None = None)
Bases: BasePerBatchPlotCallback
Base processing class for additional metrics.
- class anemoi.training.diagnostics.callbacks.plot.PlotSpectrum(config: OmegaConf, sample_idx: int, parameters: list[str], min_delta: float | None = None, every_n_batches: int | None = None)
Bases: BasePlotAdditionalMetrics
Plots the power spectrum, comparing target and prediction.
The actual increment (output - input) is plotted for prognostic variables, while the output is plotted for diagnostic ones.
- class anemoi.training.diagnostics.callbacks.plot.PlotHistogram(config: OmegaConf, sample_idx: int, parameters: list[str], precip_and_related_fields: list[str] | None = None, every_n_batches: int | None = None)
Bases: BasePlotAdditionalMetrics
Plots histograms comparing target and prediction.
The actual increment (output - input) is plotted for prognostic variables, while the output is plotted for diagnostic ones.
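Both additional-metrics callbacks can be enabled under config.diagnostics.plot.callbacks; a minimal sketch reusing the shared plot settings, with kwargs taken from the signatures above:
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotSpectrum
  sample_idx: ${diagnostics.plot.sample_idx}
  parameters: ${diagnostics.plot.parameters}
- _target_: anemoi.training.diagnostics.callbacks.plot.PlotHistogram
  sample_idx: ${diagnostics.plot.sample_idx}
  parameters: ${diagnostics.plot.parameters}
  precip_and_related_fields: ${diagnostics.plot.precip_and_related_fields}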
- class anemoi.training.diagnostics.callbacks.provenance.ParentUUIDCallback(config: OmegaConf)
Bases: Callback
A callback that retrieves the parent UUID for a model, if it is a child model.
- on_load_checkpoint(trainer: pl.Trainer, pl_module: pl.LightningModule, checkpoint: torch.nn.Module) -> None
Called when loading a model checkpoint, use to reload state.
- Parameters:
  - trainer – the current Trainer instance.
  - pl_module – the current LightningModule instance.
  - checkpoint – the full checkpoint dictionary that got loaded by the Trainer.