Ensemble CRPS-based training
This guide is intended for users who want to train an ensemble CRPS-based model and are already familiar with the basic training configurations.
It focuses on the changes relative to deterministic training. Detailed reference material for graph-based truncation, multiscale loss configuration, and residual projections lives in Field Truncation, Time Aggregate Loss Functions, and Residual connections.
CRPS training requires the following changes relative to deterministic training:
| Component | Deterministic | CRPS |
|---|---|---|
| Forecaster | | |
| Strategy | | |
| Training loss | | |
| Model | | |
| Datamodule | | |
Changes in System config
system:
  input:
    truncation: ${data.resolution}-O32-linear.mat.npz
    truncation_inv: O32-${data.resolution}-linear.mat.npz
  hardware:
The truncation and truncation_inv entries can be used in both deterministic and CRPS training. As described in Field Truncation, truncation smooths the skip connection and can also be reused for multiscale loss computation.
Graph-based truncation is also supported via
graph.projections.truncation together with
model.residual: TruncatedConnection. The canonical graph-based and
file-based examples are documented in
Field Truncation and
Residual connections.
CRPS training uses a different DDP strategy, which requires the number of GPUs per ensemble to be specified.
Changes in datamodule config
The datamodule needs to be set to
AnemoiEnsDatasetsDataModule.
AnemoiEnsDatasetsDataModule can be used with a single initial
condition shared by all ensemble members or with perturbed initial
conditions. The perturbed initial conditions need to be part of your dataset.
Changes in model config
The config group for the model is set to transformer_ens.yaml, which
specifies the AnemoiEnsModelEncProcDec class with the Graph
Transformer encoder/decoder and a transformer processor.
Changes in transformer_ens.yaml with respect to transformer.yaml are:
model:
  model:
    _target_: anemoi.models.models.ens_encoder_processor_decoder.AnemoiEnsModelEncProcDec
A different model class is used for CRPS training.
noise_injector:
  _target_: anemoi.models.layers.ensemble.NoiseConditioning
  noise_std: 1
  noise_channels_dim: 4
  noise_mlp_hidden_dim: 32
  inject_noise: True
Each ensemble member samples random noise at every time step. The noise is embedded and injected into the latent space of the processor using a conditional layer norm.
Noise can optionally be projected before conditioning:
- noise_matrix loads a precomputed sparse matrix from disk.
- noise_edges_name points to a graph edge type that maps a custom source node set to the hidden grid.
In the graph-based form, define the source nodes and corresponding source-to-hidden edge in the graph config, then point the noise injector to that edge:
graph:
  nodes:
    noise:
      node_builder:
        _target_: anemoi.graphs.nodes.ReducedGaussianGridNodes
        grid: o32
  edges:
    - source_name: noise
      target_name: hidden
      edge_builders:
        - _target_: anemoi.graphs.edges.KNNEdges
          num_nearest_neighbours: 32
      attributes:
        gauss_weight:
          _target_: anemoi.graphs.edges.attributes.GaussianDistanceWeights
          norm: l1
          sigma: 0.1

model:
  noise_injector:
    _target_: anemoi.models.layers.ensemble.NoiseConditioning
    noise_std: 1
    noise_channels_dim: 4
    noise_mlp_hidden_dim: 32
    noise_edges_name: [noise, to, hidden]
    edge_weight_attribute: gauss_weight
To condition the latent space on the noise, the processor must use a
different layer norm, here
anemoi.models.layers.normalization.ConditionalLayerNorm.
See Models for the ensemble model
architecture and Normalization for the
normalization layers.
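As an illustration of the general idea behind a conditional layer norm, the following NumPy sketch normalises the latent features and then applies a scale and shift that depend on the noise embedding. It shows the generic technique only and is not the ConditionalLayerNorm implementation in anemoi; the function name, the weight matrices w_scale and w_shift, and the array shapes are assumptions made for this example.

```python
import numpy as np


def conditional_layer_norm(x, cond, w_scale, w_shift, eps=1e-5):
    """Generic conditional layer norm (illustrative sketch, not the anemoi code).

    x:       latent features, shape (num_nodes, channels)
    cond:    conditioning vector per node (e.g. embedded noise), shape (num_nodes, cond_dim)
    w_scale: hypothetical linear map from cond to a per-channel scale, shape (cond_dim, channels)
    w_shift: hypothetical linear map from cond to a per-channel shift, shape (cond_dim, channels)
    """
    # Standard layer norm over the channel dimension.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mean) / np.sqrt(var + eps)
    # The scale and shift are functions of the conditioning input.
    scale = cond @ w_scale
    shift = cond @ w_shift
    return x_hat * (1.0 + scale) + shift


rng = np.random.default_rng(0)
num_nodes, channels, noise_dim = 5, 8, 4

x = rng.normal(size=(num_nodes, channels))
w_scale = 0.1 * rng.normal(size=(noise_dim, channels))
w_shift = 0.1 * rng.normal(size=(noise_dim, channels))

# Two ensemble members share the same latent state but draw different noise.
noise_a = rng.normal(size=(num_nodes, noise_dim))
noise_b = rng.normal(size=(num_nodes, noise_dim))
out_a = conditional_layer_norm(x, noise_a, w_scale, w_shift)
out_b = conditional_layer_norm(x, noise_b, w_scale, w_shift)

# With a zero conditioning vector the layer reduces to a plain layer norm.
out_plain = conditional_layer_norm(x, np.zeros((num_nodes, noise_dim)), w_scale, w_shift)
```

Because the scale and shift depend on the sampled noise, members that start from the same latent state produce different outputs, which is what lets the ensemble spread.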
Changes in training config
training:
  method:
    _target_: anemoi.training.train.methods.EnsembleTraining
    ensemble_size_per_device: 4
  max_epochs: 20
The model task is set to
anemoi.training.train.tasks.GraphEnsForecaster for CRPS
training to deal with the ensemble members. The number of ensemble
members per device needs to be specified.
Note
The total number of ensemble members is the product of ensemble_size_per_device and the ratio of num_gpus_per_ensemble to num_gpus_per_model.
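The arithmetic in the note can be checked with a small helper. This is a sketch for illustration only: the function name is hypothetical, and the divisibility check is an assumption made here rather than documented behaviour.

```python
def total_ensemble_members(ensemble_size_per_device, num_gpus_per_ensemble, num_gpus_per_model):
    """Total members = ensemble_size_per_device * (num_gpus_per_ensemble / num_gpus_per_model)."""
    # Assumption for this sketch: an ensemble group holds a whole number of model instances.
    if num_gpus_per_ensemble % num_gpus_per_model != 0:
        raise ValueError("num_gpus_per_ensemble must be a multiple of num_gpus_per_model")
    model_instances = num_gpus_per_ensemble // num_gpus_per_model
    return ensemble_size_per_device * model_instances


# Single-GPU setup, as in the example config of this guide: 4 * (1 / 1) = 4 members.
single_gpu = total_ensemble_members(4, 1, 1)

# Four GPUs per ensemble with the model sharded over two GPUs: 4 * (4 / 2) = 8 members.
sharded = total_ensemble_members(4, 4, 2)

print(single_gpu, sharded)
```

Sharding the model over more GPUs therefore reduces the number of members unless num_gpus_per_ensemble is increased accordingly.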
strategy:
  _target_: anemoi.training.distributed.strategy.DDPEnsGroupStrategy
  num_gpus_per_ensemble: ${system.hardware.num_gpus_per_ensemble}
  num_gpus_per_model: ${system.hardware.num_gpus_per_model}
CRPS training uses a different strategy, which parallelises the training over the ensemble members and shards the model.
training_loss:
  datasets:
    your_dataset_name:
      _target_: anemoi.training.losses.CRPS
      scalers: ["variable"]
      ignore_nans: False
      alpha: 0.95
The loss function for CRPS training must be specified. Here, we use
the anemoi.training.losses.kcrps.AlmostFairKernelCRPS loss
function (Lang et al. (2024b)).
The alpha parameter controls the trade-off between the CRPS and the fair CRPS: alpha = 1 gives the fair CRPS and alpha = 0 the standard CRPS.
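To make the role of alpha concrete, here is a minimal NumPy sketch of the almost fair CRPS for a single scalar forecast, written as the standard kernel form with the spread term scaled by 1 - (1 - alpha) / M for M members. It is an illustration under these assumptions, not the anemoi loss class: the function name is hypothetical, and scalers, NaN handling and batching are left out.

```python
import numpy as np


def almost_fair_crps(members, obs, alpha=0.95):
    """Almost fair CRPS for one scalar observation (illustrative sketch).

    members: 1-D array of M ensemble forecasts
    obs:     scalar observation
    alpha:   1.0 gives the fair CRPS, 0.0 the standard ensemble CRPS
    """
    members = np.asarray(members, dtype=float)
    m = members.shape[0]
    if m < 2:
        raise ValueError("at least two ensemble members are required")
    # Mean absolute error of the members against the observation.
    skill = np.abs(members - obs).mean()
    # Sum of absolute differences over all ordered member pairs.
    pair_sum = np.abs(members[:, None] - members[None, :]).sum()
    # epsilon = (1 - alpha) / M interpolates between the fair and standard spread terms.
    eps = (1.0 - alpha) / m
    spread = (1.0 - eps) * pair_sum / (2.0 * m * (m - 1))
    return skill - spread


members = np.array([0.0, 1.0, 2.0, 3.0])
obs = 1.5

fair = almost_fair_crps(members, obs, alpha=1.0)      # fair CRPS
standard = almost_fair_crps(members, obs, alpha=0.0)  # standard ensemble CRPS
blended = almost_fair_crps(members, obs, alpha=0.95)

print(fair, standard, blended)
```

For any alpha the result equals alpha times the fair CRPS plus (1 - alpha) times the standard CRPS, so alpha: 0.95 stays close to the fair score while keeping a small contribution from the standard one.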
If you want multiscale CRPS training, the reference documentation for
MultiscaleLossWrapper and the two supported loss_matrices_graph
forms is in Time Aggregate Loss Functions.
Typically, the validation metrics are the same as the training loss, but different validation metrics can be added here (see Losses).
Example config
A typical config file for CRPS training is:
defaults:
  - data: zarr
  - dataloader: native_grid
  - diagnostics: evaluation
  - system: example
  - graph: encoder_decoder_only
  - model: transformer_ens
  - task: forecaster
  - training: default
  - _self_

config_validation: True

# Changes in system
system:
  input:
    truncation: ${data.resolution}-O32-linear.mat.npz
    truncation_inv: O32-${data.resolution}-linear.mat.npz
  hardware:
    num_gpus_per_ensemble: 1
    num_gpus_per_node: 1
    num_nodes: 1
    num_gpus_per_model: 1

data:
  resolution: o96

# Changes in training
training:
  method:
    _target_: anemoi.training.train.methods.EnsembleTraining
    ensemble_size_per_device: 4
  max_epochs: 20

  # Changes in strategy
  strategy:
    _target_: anemoi.training.distributed.strategy.DDPEnsGroupStrategy
    num_gpus_per_ensemble: ${system.hardware.num_gpus_per_ensemble}
    num_gpus_per_model: ${system.hardware.num_gpus_per_model}

  # Changes in training loss
  training_loss:
    datasets:
      your_dataset_name:
        _target_: anemoi.training.losses.CRPS
        scalers: ["variable"]
        ignore_nans: False
        alpha: 0.95

  # Changes in validation metrics
  validation_metrics:
    datasets:
      your_dataset_name:
        kcrps:
          _target_: anemoi.training.losses.CRPS
          scalers: []
          ignore_nans: False
          alpha: 1.0