Residual connections

Residual connections are a key architectural feature of Anemoi’s encoder-processor-decoder models, enabling more effective information flow and gradient propagation across network layers. They help mitigate issues such as vanishing gradients and support the training of deeper, more expressive models.

A residual connection links the model input directly to its output, so the network only has to learn a correction to the input state. The type of residual connection used in a model is specified under the residual key in the model configuration YAML. This modular approach lets users select and customize the residual strategy best suited to their forecasting task, whether a standard skip connection, a truncated connection, or a learnable Ornstein connection.

Standard residual (default)

The standard residual formulation used in most models is:

\[x(t+1) = x(t) + f_\theta(x(t))\]

where \(f_\theta\) is the learned model increment. This preserves the full input state and adds a correction.
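As a minimal sketch of the standard formulation, the NumPy snippet below applies the update repeatedly in an autoregressive rollout. The function names and the constant toy_increment are hypothetical stand-ins for the learned model \(f_\theta\); none of them are part of Anemoi.

```python
import numpy as np


def residual_step(x_t, increment_fn):
    """One standard residual step: x(t+1) = x(t) + f_theta(x(t))."""
    return x_t + increment_fn(x_t)


def rollout(x0, increment_fn, n_steps):
    """Autoregressive rollout: each output is fed back as the next input."""
    states = [x0]
    for _ in range(n_steps):
        states.append(residual_step(states[-1], increment_fn))
    return states


def toy_increment(x):
    # Hypothetical stand-in for the learned network f_theta: a constant tendency.
    return 0.1 * np.ones_like(x)


states = rollout(np.zeros(3), toy_increment, n_steps=5)
print(states[-1])  # [0.5 0.5 0.5]
```

Because each step only adds an increment to the previous state, the network never has to reproduce the full state itself.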

Skip Connection

Returns the most recent timestep unchanged:

\[\text{residual}(x) = x_t\]

This is the default residual and corresponds to the standard formulation above (the model output is added externally by the architecture).

class anemoi.models.layers.residual.SkipConnection(step: int = -1, **_)

Bases: BaseResidualConnection

Skip connection module

This layer returns the most recent timestep from the input sequence.

This module is used to bypass processing layers and directly pass the latest input forward.

forward(x: Tensor, grid_shard_sizes=None, model_comm_group=None, n_step_output: int | None = None) → Tensor

Return the last timestep of the input sequence.
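The selection performed by the skip connection can be sketched in NumPy, assuming the (batch, time, ens, nodes, features) layout used in the TruncatedConnection example below. skip_connection is a hypothetical helper; it simply drops the time axis, and how the real module shapes its output (for example when n_step_output is set) is not described here.

```python
import numpy as np


def skip_connection(x, step=-1):
    """Select one timestep from x with layout (batch, time, ens, nodes, features).

    step=-1 picks the most recent timestep, mirroring the documented default.
    """
    return x[:, step, ...]


x = np.random.default_rng(0).normal(size=(2, 4, 1, 10, 3))
res = skip_connection(x)
print(res.shape)  # (2, 1, 10, 3)
```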

Truncated Connection

Projects the input to a coarser grid and back, removing high-frequency content from the skip connection via sparse spatial projections:

\[\text{residual}(x) = P_{\text{up}} \, P_{\text{down}} \, x_t\]

where \(P_{\text{down}}\) maps to the coarse grid and \(P_{\text{up}}\) maps back to the original resolution.
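A toy version of \(P_{\text{up}} \, P_{\text{down}}\) on a one-dimensional ring of 8 points, with a 4-point coarse grid, shows why the projection removes small scales. The dense matrices are illustrative stand-ins (box average down, copy from the parent coarse node up); the real layer uses sparse projections built from the graph or loaded from .npz files.

```python
import numpy as np

n_fine, n_coarse = 8, 4

# P_down: each coarse node averages two neighbouring fine nodes (box average).
p_down = np.zeros((n_coarse, n_fine))
for i in range(n_coarse):
    p_down[i, 2 * i : 2 * i + 2] = 0.5

# P_up: each fine node copies the value of its parent coarse node.
p_up = np.zeros((n_fine, n_coarse))
for j in range(n_fine):
    p_up[j, j // 2] = 1.0


def truncate(x_t):
    """residual(x) = P_up @ P_down @ x_t for a single field on the fine grid."""
    return p_up @ (p_down @ x_t)


smooth = np.ones(n_fine)
wiggle = np.array([1.0, -1.0] * 4)  # highest-frequency mode on this grid

print(truncate(smooth))  # all ones: the large-scale field passes through
print(truncate(wiggle))  # all zeros: the grid-scale oscillation is removed
```

Because \(P_{\text{down}} P_{\text{up}}\) is the identity for this particular pair, applying the truncation twice gives the same result as applying it once.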

class anemoi.models.layers.residual.TruncatedConnection(graph: HeteroData | None = None, src_node_weight_attribute: str | None = None, edge_weight_attribute: str | None = None, truncation_config: dict | None = None, truncation_up_edges_name: tuple[str, str, str] | None = None, truncation_down_edges_name: tuple[str, str, str] | None = None, data_node_name: str = 'data', autocast: bool = False, row_normalize: bool = False, truncation_up_file_path: str | None = None, truncation_down_file_path: str | None = None, **_)

Bases: BaseResidualConnection

Truncated skip connection.

Applies a coarse-graining and reconstruction of input features using sparse projections to truncate high-frequency features.

Edge names and the edge-weight attribute are expected to be pre-resolved by ProjectionCreator and passed in directly. File-path loading is still supported as an alternative to the graph-based path.

Parameters:
  • graph (HeteroData, optional) – Graph containing the truncation subgraphs.

  • src_node_weight_attribute (str, optional) – Source-node attribute used as additional projection weights.

  • edge_weight_attribute (str, optional) – Edge attribute used as projection weights (default: gauss_weight).

  • truncation_config (dict, optional) – Configuration used to build or load the truncation projections.

  • truncation_up_edges_name (tuple[str, str, str], optional) – Pre-resolved (src, relation, dst) edge type for the up-projection.

  • truncation_down_edges_name (tuple[str, str, str], optional) – Pre-resolved (src, relation, dst) edge type for the down-projection.

  • data_node_name (str, default "data") – Name of the data nodes in graph.

  • autocast (bool, default False) – Whether to use automatic mixed precision for the projections.

  • row_normalize (bool, default False) – Normalize projection weights per target node so each row sums to 1.

  • truncation_up_file_path (str, optional) – Deprecated path to an .npz file for the up-projection matrix.

  • truncation_down_file_path (str, optional) – Deprecated path to an .npz file for the down-projection matrix.
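Row normalization as described for row_normalize can be sketched as follows, treating each row of a dense weight matrix as one target node. The helper name and the choice to leave rows without any weight at zero are assumptions for illustration.

```python
import numpy as np


def row_normalize(weights):
    """Scale each row (one target node) so its weights sum to 1."""
    row_sums = weights.sum(axis=1, keepdims=True)
    # Assumption: rows with no incoming weight stay zero instead of dividing by 0.
    safe = np.where(row_sums == 0, 1.0, row_sums)
    return weights / safe


w = np.array([[2.0, 2.0, 0.0],
              [0.0, 1.0, 3.0],
              [0.0, 0.0, 0.0]])
normalized = row_normalize(w)
print(normalized)
```

With every non-empty row summing to 1, a spatially constant field is mapped to the same constant, which is usually the desired behaviour for a coarse-graining or interpolation operator.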

Examples

>>> import torch
>>> from anemoi.models.layers.residual import TruncatedConnection
>>> # Graph-based path (edge names supplied by ProjectionCreator)
>>> conn = TruncatedConnection(
...     graph=graph,
...     data_node_name="data",
...     truncation_down_edges_name=("data", "to", "truncation"),
...     truncation_up_edges_name=("truncation", "to", "data"),
...     edge_weight_attribute="gauss_weight",
... )
>>> x = torch.randn(2, 4, 1, 40192, 44)  # (batch, time, ens, nodes, features)
>>> out = conn(x)
>>> print(out.shape)
torch.Size([2, 4, 1, 40192, 44])
>>> # File-based path
>>> conn = TruncatedConnection(
...     truncation_down_file_path="n320_to_o96.npz",
...     truncation_up_file_path="o96_to_n320.npz",
... )
>>> x = torch.randn(2, 4, 1, 40192, 44)
>>> out = conn(x)
>>> print(out.shape)
torch.Size([2, 4, 1, 40192, 44])

forward(x: Tensor, grid_shard_sizes=None, model_comm_group=None, n_step_output: int | None = None) → Tensor

Apply truncated skip connection.

Configuration

Both connection types are configured under the residual key in the model config. TruncatedConnection accepts and silently ignores keyword arguments intended for its sibling classes, such as step, so switching between connection types only requires changing _target_.
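The kwarg-swallowing behaviour comes from the trailing **_ in the documented signatures. The two toy classes below are hypothetical and only mimic that pattern; they show how one set of keyword arguments can be passed to either class without raising a TypeError.

```python
from __future__ import annotations


class HypotheticalSkip:
    """Toy class mimicking a signature like SkipConnection(step=-1, **_)."""

    def __init__(self, step: int = -1, **_):
        self.step = step


class HypotheticalTruncated:
    """Toy class mimicking a signature like TruncatedConnection(truncation_config=None, **_)."""

    def __init__(self, truncation_config: dict | None = None, **_):
        # Keyword arguments this class does not use (e.g. `step`) land in `_`
        # and are silently discarded.
        self.truncation_config = truncation_config or {}


# The same kwargs can be passed to either class; only the target changes.
shared_kwargs = {"step": -1, "truncation_config": {"grid": "o32"}}
skip = HypotheticalSkip(**shared_kwargs)
trunc = HypotheticalTruncated(**shared_kwargs)
print(skip.step, trunc.truncation_config)  # -1 {'grid': 'o32'}
```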

TruncatedConnection supports two modes, both via the truncation_config key:

  • On-the-fly: the truncation subgraph is built at runtime from the main graph using a coarser grid specification.

  • File-based: precomputed .npz projection matrices are loaded from disk.

Choose one mode per config; do not mix the two within the same truncation_config block.

On-the-fly example:

model:
  residual:
    _target_: anemoi.models.layers.residual.TruncatedConnection
    truncation_config:
      grid: o32
      num_nearest_neighbours: 3
      sigma: 1.0

File-based example:

model:
  residual:
    _target_: anemoi.models.layers.residual.TruncatedConnection
    truncation_config:
      truncation_down_file_path: /path/to/O96-O32-grid-box-average.mat.npz
      truncation_up_file_path: /path/to/O32-O96-grid-box-average.mat.npz
      row_normalize: false

Note

The top-level truncation_up_file_path and truncation_down_file_path kwargs are still accepted for backward compatibility, but the recommended approach is to move them inside truncation_config.

Learnable residual (Ornstein)

Learnable residual connections introduce a trainable scaling parameter \(\alpha\) on the residual connection, giving a formulation equivalent to a discretized Ornstein–Uhlenbeck process:

\[x(t+1) = \alpha \cdot x(t) + f_\theta(x(t))\]

With \(\alpha\) trainable and \(\alpha < 1\), the carried-over state, including any error it contains, is scaled down at each step rather than preserved exactly. This helps limit error growth during autoregressive integration.
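In the simplest case, where the increment does not depend on the state, the effect of \(\alpha\) on a perturbation can be checked directly: two runs started \(\delta_0\) apart stay \(\alpha^n \delta_0\) apart after \(n\) steps. The sketch below assumes a constant increment as a stand-in for \(f_\theta\); with a real network, error growth also depends on how \(f_\theta\) responds to the perturbation.

```python
def rollout_error(alpha, delta0, n_steps, increment=0.3):
    """Gap between two runs of x(t+1) = alpha * x(t) + c started delta0 apart.

    The constant c stands in for f_theta, so it cancels in the gap.
    """
    a, b = 1.0, 1.0 + delta0
    for _ in range(n_steps):
        a = alpha * a + increment
        b = alpha * b + increment
    return abs(b - a)


print(rollout_error(alpha=1.0, delta0=0.1, n_steps=20))  # ≈ 0.1: error carried forward unchanged
print(rollout_error(alpha=0.9, delta0=0.1, n_steps=20))  # ≈ 0.0122 (= 0.1 * 0.9**20): error contracted
```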

Two variants are available, offering increasing degrees of spatial structure in the learnable parameters.

Scalar Ornstein Connection

A single learnable scalar \(\alpha_v\) per prognostic variable \(v\):

\[\text{residual}(x)_v = (1 - \alpha_v) \cdot x_{t,v}\]

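where \(\alpha_v \in (\alpha_{\text{buff}}, 1)\) is parameterized via a sigmoid; \(1 - \alpha_v\) therefore plays the role of the scaling factor \(\alpha\) in the general formulation above, and \(\alpha_v\) corresponds to theta in the class description below. This is the simplest Ornstein variant: no spatial structure, just a per-variable damping factor. The class description also lists an additive mean term mu and optional forcing regressors.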
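The text states that \(\alpha_v\) is parameterized via a sigmoid but not the exact mapping. The shifted and scaled sigmoid below is one straightforward way to obtain the range \((\alpha_{\text{buff}}, 1)\) from an unconstrained parameter; bounded_alpha is a hypothetical name, not the library's implementation.

```python
import numpy as np


def bounded_alpha(raw, alpha_buff=0.0):
    """Map unconstrained parameters to (alpha_buff, 1) with a sigmoid."""
    return alpha_buff + (1.0 - alpha_buff) / (1.0 + np.exp(-raw))


raw = np.array([-10.0, 0.0, 10.0])
alpha = bounded_alpha(raw, alpha_buff=0.2)
print(alpha)        # values stay strictly inside (0.2, 1)
print(1.0 - alpha)  # retention factor applied to x_t
```

Optimizing the unconstrained raw value keeps gradient steps free of explicit clipping while the damping factor never leaves its allowed range.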

class anemoi.models.layers.residual.ScalarOrnsteinConnection(theta_init: float = 0.0, theta_buff: float = 0.0, theta_train: bool = True, regressors: list[str] | None = None, graph: HeteroData | None = None, statistics: dict | None = None, data_indices=None, dataset_name: str | None = None, **_)

Bases: BaseResidualConnection

Ornstein residual with learnable scalars theta and mu.

residual(x) = (1 - theta) * x + mu + sum_i beta_i * f_i

where theta is in (theta_buff, 1) and learned independently for each prognostic variable. f_i are forcing variables listed in regressors. There is no spatial or spectral structure.

Parameters:
  • theta_init (float) – Initial value for theta. If 0 and statistics are available, auto-initialized from tendency statistics.

  • theta_buff (float) – Lower bound buffer for theta. Theta is constrained to (theta_buff, 1).

  • theta_train (bool) – Whether theta is a trainable parameter.

  • regressors (list[str] | None) – Variable names to use as regressors.
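Putting the terms of the class description together, a NumPy sketch of the scalar Ornstein residual for a single timestep might look like this. The array layout, and giving mu and beta one entry per prognostic variable, are assumptions for illustration; theta is assumed to be already constrained to (theta_buff, 1).

```python
import numpy as np


def scalar_ornstein_residual(x_t, forcings, theta, mu, beta):
    """residual(x) = (1 - theta) * x + mu + sum_i beta_i * f_i.

    x_t:      (nodes, n_vars) prognostic state at the latest timestep
    forcings: (nodes, n_forcings) regressor fields f_i
    theta:    (n_vars,) per-variable damping, assumed already in (theta_buff, 1)
    mu:       (n_vars,) per-variable bias (assumed layout)
    beta:     (n_forcings, n_vars) regressor weights (assumed layout)
    """
    return (1.0 - theta) * x_t + mu + forcings @ beta


nodes, n_vars, n_forcings = 5, 2, 1
x_t = np.ones((nodes, n_vars))
forcings = np.full((nodes, n_forcings), 2.0)
theta = np.array([0.1, 0.5])
mu = np.array([0.0, 1.0])
beta = np.array([[0.25, 0.0]])

out = scalar_ornstein_residual(x_t, forcings, theta, mu, beta)
print(out[0])  # [1.4 1.5]
```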

forward(x: Tensor, grid_shard_sizes=None, model_comm_group=None, n_step_output: int | None = None) → Tensor

Apply the scalar Ornstein residual connection.

Spectral Ornstein Connection

Spatially varying \(\alpha\), bias \(\mu\), and regressor weights \(\beta\), defined as smooth functions on the sphere via spherical harmonic (SH) coefficients:

\[\text{residual}(x)_v = \bigl(1 - \alpha_v(s)\bigr) \cdot x_{t,v} + \mu_v(s) + \sum_i \beta_{i,v}(s) \cdot f_i\]

where \(s\) denotes the spatial location, \(\alpha_v(s)\), \(\mu_v(s)\), and \(\beta_{i,v}(s)\) are reconstructed from low-order SH coefficients (controlled by lmax), and \(f_i\) are optional forcing regressors.
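To illustrate how a handful of low-order coefficients define a smooth field on the sphere, the sketch below hand-codes the orthonormal real spherical harmonics of degrees 0 and 1 (up to sign convention) on a regular lat-lon grid. It is only a stand-in for the inverse SHT: the layer itself stores complex coefficients and uses its own transform, and the helper names here are hypothetical.

```python
import numpy as np


def real_sh_basis(lat_deg, lon_deg):
    """Orthonormal real spherical harmonics for degrees l = 0 and l = 1.

    Returns shape (4, *lat.shape), ordered as Y_00, Y_1-1, Y_10, Y_11.
    """
    lat, lon = np.deg2rad(lat_deg), np.deg2rad(lon_deg)
    # Unit-sphere Cartesian coordinates.
    x = np.cos(lat) * np.cos(lon)
    y = np.cos(lat) * np.sin(lon)
    z = np.sin(lat)
    c0 = 0.5 / np.sqrt(np.pi)
    c1 = np.sqrt(3.0 / (4.0 * np.pi))
    return np.stack([np.full_like(z, c0), c1 * y, c1 * z, c1 * x])


def reconstruct(coeffs, lat_deg, lon_deg):
    """Sum coefficient-weighted basis functions into a spatial field."""
    return np.tensordot(coeffs, real_sh_basis(lat_deg, lon_deg), axes=1)


lat, lon = np.meshgrid(np.linspace(-90, 90, 7), np.linspace(0, 330, 12), indexing="ij")
c00 = 0.3 * 2.0 * np.sqrt(np.pi)  # Y_00 coefficient giving a global value of 0.3

# Only the degree-0 coefficient: alpha is the same everywhere.
alpha_const = reconstruct(np.array([c00, 0.0, 0.0, 0.0]), lat, lon)
# Adding a Y_10 term makes alpha vary smoothly from south to north.
alpha_ns = reconstruct(np.array([c00, 0.0, 0.1, 0.0]), lat, lon)
```

Limiting the expansion to low degrees keeps the learned fields smooth and the parameter count small, regardless of the grid resolution.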

When truncate=True, a learnable spectral low-pass filter is applied to the input fields before computing the residual. This removes high-frequency content from the skip connection, so the model must reconstruct fine-scale detail itself rather than inherit it from the input. An anti-aliasing blend (anti_aliasing=True, the default) smoothly mixes the filtered and unfiltered fields.
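The filtered/unfiltered blend can be illustrated with a one-dimensional periodic analogue using NumPy's FFT. This sketch rests on several assumptions: the cutoff is a fixed integer, whereas the layer's filter is learnable and acts on spherical harmonic coefficients, and the linear blend weight only stands in for whatever anti-aliasing scheme the layer actually uses. lowpass_blend is a hypothetical name.

```python
import numpy as np


def lowpass_blend(field, cutoff, blend):
    """1-D periodic analogue of spectral truncation with a filtered/unfiltered blend.

    cutoff: highest wavenumber kept (hard cutoff here; the layer's filter is learnable).
    blend:  weight in [0, 1] on the filtered field; 1 - blend goes to the raw field.
    """
    spec = np.fft.rfft(field)
    spec[cutoff + 1 :] = 0.0  # zero all wavenumbers above the cutoff
    filtered = np.fft.irfft(spec, n=field.size)
    return blend * filtered + (1.0 - blend) * field


n = 64
s = np.arange(n) * 2.0 * np.pi / n
large = np.sin(s)              # wavenumber 1
small = 0.5 * np.sin(20 * s)   # wavenumber 20
field = large + small

fully_filtered = lowpass_blend(field, cutoff=4, blend=1.0)  # large scale only
half_blended = lowpass_blend(field, cutoff=4, blend=0.5)    # keeps half of the small scale
```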

class anemoi.models.layers.residual.SpectralOrnsteinConnection(lmax: int = 2, grid: str = 'regular', theta_init: float = 0.0, theta_buff: float = 0.0, use_mean: bool = True, regressors: list[str] | None = None, truncate: bool = False, skip_truncate_variables: list[str] | None = None, anti_aliasing: bool = True, graph: HeteroData | None = None, statistics: dict | None = None, data_indices=None, dataset_name: str | None = None, **_)

Bases: BaseResidualConnection

Ornstein residual with learnable spatially-varying theta and mu defined via spherical harmonics.

residual(x) = (1 - theta(s)) * x + mu(s) + sum_i beta_i(s) * f_i

where theta/mu/beta_i are stored as lmax x lmax complex SH coefficients (per prognostic variable), and the spatial fields are obtained via inverse SHT. f_i are forcing variables listed in regressors.

When truncate=True, a learnable spectral low-pass filter is applied to the input fields before computing the residual. This truncates high-frequency content from the skip connection.

Parameters:
  • lmax (int) – Maximum spherical harmonic degree for the theta/mu coefficients.

  • grid (str) – Grid type: "regular" for regular lat-lon, "octahedral" for octahedral reduced grids. Other types are not currently supported and will raise an error.

  • theta_init (float) – Initial value for theta.

  • theta_buff (float) – Lower bound buffer for theta.

  • use_mean (bool) – Whether to include the mean (mu) term.

  • regressors (list[str] | None) – Variable names to use as spatially-varying regressors.

  • truncate (bool) – If True, apply a learnable spectral low-pass filter to the input fields.

  • skip_truncate_variables (list[str] | None) – Variable names to exclude from spectral truncation (only used when truncate=True).

  • anti_aliasing (bool) – If True (and truncate=True), use anti-aliasing blending in the filter.
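Excluding variables from truncation via skip_truncate_variables amounts to restoring the unfiltered values for those variables. The sketch below assumes a (nodes, variables) layout, hypothetical variable names, and a deliberately crude node-mean filter in place of the spectral one; truncate_selected is not a library function.

```python
import numpy as np


def truncate_selected(x, variable_names, skip_names, truncate_fn):
    """Apply truncate_fn to all variables except those named in skip_names.

    x has shape (nodes, n_vars); variable_names lists the n_vars column names.
    """
    skip = set(skip_names or [])
    untouched = np.array([name in skip for name in variable_names], dtype=bool)
    out = np.array(truncate_fn(x), copy=True)
    out[:, untouched] = x[:, untouched]  # restore the excluded variables
    return out


def node_mean(x):
    # Crudest possible low-pass stand-in: keep only each variable's mean over nodes.
    return np.repeat(x.mean(axis=0, keepdims=True), x.shape[0], axis=0)


x = np.array([[1.0, 10.0], [3.0, 20.0], [5.0, 30.0]])
out = truncate_selected(x, ["var_a", "var_b"], ["var_b"], node_mean)
print(out)  # var_a smoothed to its mean (3.0); var_b left unchanged
```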

forward(x: Tensor, grid_shard_sizes=None, model_comm_group=None, n_step_output: int | None = None) → Tensor

Apply the spectral Ornstein residual connection.