Residual connections
Residual connections are a key architectural feature in Anemoi’s encoder-processor-decoder models, enabling more effective information flow and gradient propagation across network layers. They help mitigate issues such as vanishing gradients and support the training of deeper and more expressive models.
The configurable residual connections link input data to output data.
The type of residual connection used in a model is specified under the
residual key in the model configuration YAML. This modular approach
allows users to select and customize the residual strategy best suited
for their forecasting task, whether it be a standard skip connection or
a truncated connection.
Standard residual (default)
The standard residual formulation used in most models is:
\[x_{t+1} = x_t + f_\theta(x_t)\]
where \(f_\theta\) is the learned model increment. This preserves the full input state and adds a correction.
Skip Connection
Returns the most recent timestep unchanged:
\[\text{residual}(x) = x_t\]
This is the default residual and corresponds to the standard formulation above (the model output is added externally by the architecture).
Truncated Connection
Projects the input to a coarser grid and back, removing high-frequency content from the skip connection via sparse spatial projections:
\[\text{residual}(x) = P_{\text{up}} \, P_{\text{down}} \, x_t\]
where \(P_{\text{down}}\) maps to the coarse grid and \(P_{\text{up}}\) maps back to the original resolution.
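As a rough illustration of the down-and-up projection, the following pure-Python sketch applies two toy projection matrices on a 1-D grid with 8 fine nodes and 4 coarse nodes. The matrices, the grid sizes and the helper names (`matvec`, `row_normalize`, `truncate`) are assumptions made for this example only; the actual class works with sparse projections taken from a graph or from .npz files.

```python
def matvec(matrix, vector):
    """Multiply a dense matrix (given as a list of rows) by a vector."""
    return [sum(w * v for w, v in zip(row, vector)) for row in matrix]


def row_normalize(matrix):
    """Scale each row so that its weights sum to 1."""
    return [[w / sum(row) for w in row] for row in matrix]


N_FINE, N_COARSE = 8, 4

# Hypothetical P_down: each coarse node averages its two fine-grid children.
p_down = row_normalize(
    [[1.0 if j // 2 == i else 0.0 for j in range(N_FINE)] for i in range(N_COARSE)]
)
# Hypothetical P_up: each fine node copies the value of its coarse parent.
p_up = row_normalize(
    [[1.0 if j == i // 2 else 0.0 for j in range(N_COARSE)] for i in range(N_FINE)]
)


def truncate(x):
    """Apply P_up P_down to a field defined on the fine grid."""
    return matvec(p_up, matvec(p_down, x))


smooth = [3.0] * N_FINE                       # constant, large-scale field
noisy = [(-1.0) ** j for j in range(N_FINE)]  # grid-scale oscillation

smooth_out = truncate(smooth)
noisy_out = truncate(noisy)
```

In this toy setting the constant field comes back unchanged, while the grid-scale oscillation is averaged away to zero, which is the kind of high-frequency content the truncated connection is meant to drop.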
- class anemoi.models.layers.residual.TruncatedConnection(graph: HeteroData | None = None, src_node_weight_attribute: str | None = None, edge_weight_attribute: str | None = None, truncation_config: dict | None = None, truncation_up_edges_name: tuple[str, str, str] | None = None, truncation_down_edges_name: tuple[str, str, str] | None = None, data_node_name: str = 'data', autocast: bool = False, row_normalize: bool = False, truncation_up_file_path: str | None = None, truncation_down_file_path: str | None = None, **_)
Bases: BaseResidualConnection
Truncated skip connection.
Applies a coarse-graining and reconstruction of input features using sparse projections to truncate high-frequency features.
Edge names and the edge-weight attribute are expected to be pre-resolved by ProjectionCreator and passed in directly. File-path loading is still supported as an alternative to the graph-based path.
Parameters:
graph (HeteroData, optional) – Graph containing the truncation subgraphs.
src_node_weight_attribute (str, optional) – Source-node attribute used as additional projection weights.
edge_weight_attribute (str, optional) – Edge attribute used as projection weights (default: gauss_weight).
truncation_config (dict, optional) – Configuration used to build or load the truncation projections.
truncation_up_edges_name (tuple[str, str, str], optional) – Pre-resolved (src, relation, dst) edge type for the up-projection.
truncation_down_edges_name (tuple[str, str, str], optional) – Pre-resolved (src, relation, dst) edge type for the down-projection.
data_node_name (str, default "data") – Name of the data nodes in graph.
autocast (bool, default False) – Whether to use automatic mixed precision for the projections.
row_normalize (bool, optional) – Normalize projection weights per target node so each row sums to 1.
truncation_up_file_path (str, optional) – Deprecated path to an .npz file for the up-projection matrix.
truncation_down_file_path (str, optional) – Deprecated path to an .npz file for the down-projection matrix.
Examples
>>> import torch
>>> from anemoi.models.layers.residual import TruncatedConnection

>>> # Graph-based path (edge names supplied by ProjectionCreator)
>>> # `graph` is a HeteroData object containing the truncation subgraphs
>>> conn = TruncatedConnection(
...     graph=graph,
...     data_node_name="data",
...     truncation_down_edges_name=("data", "to", "truncation"),
...     truncation_up_edges_name=("truncation", "to", "data"),
...     edge_weight_attribute="gauss_weight",
... )
>>> x = torch.randn(2, 4, 1, 40192, 44)  # (batch, time, ens, nodes, features)
>>> out = conn(x)
>>> print(out.shape)
torch.Size([2, 4, 1, 40192, 44])

>>> # File-based path
>>> conn = TruncatedConnection(
...     truncation_down_file_path="n320_to_o96.npz",
...     truncation_up_file_path="o96_to_n320.npz",
... )
>>> x = torch.randn(2, 4, 1, 40192, 44)
>>> out = conn(x)
>>> print(out.shape)
torch.Size([2, 4, 1, 40192, 44])
Configuration
Both connection types are configured under the residual key in the
model config. TruncatedConnection accepts sibling-class kwargs such
as step transparently, so switching between connection types requires
only changing _target_.
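The following sketch illustrates why a trailing `**_` in the constructor makes this kind of switch painless. It is not Anemoi code: the classes `SkipLike` and `TruncatedLike`, the `instantiate` helper and the `REGISTRY` lookup are hypothetical stand-ins that only mimic the `_target_` convention.

```python
class SkipLike:
    """Hypothetical stand-in for a skip-style connection that uses `step`."""

    def __init__(self, step=0, **_):
        self.step = step


class TruncatedLike:
    """Hypothetical stand-in that ignores kwargs meant for sibling classes."""

    def __init__(self, truncation_config=None, **_):
        self.truncation_config = truncation_config or {}


def instantiate(config, registry):
    """Build an object from a `_target_`-style config dictionary."""
    kwargs = {key: value for key, value in config.items() if key != "_target_"}
    return registry[config["_target_"]](**kwargs)


REGISTRY = {"SkipLike": SkipLike, "TruncatedLike": TruncatedLike}

config = {"_target_": "SkipLike", "step": 0}
skip = instantiate(config, REGISTRY)

# Only `_target_` changes; the now-unused `step` is swallowed by **_.
truncated = instantiate({**config, "_target_": "TruncatedLike"}, REGISTRY)
```

Without the catch-all `**_`, the second call would fail with an unexpected-keyword error, and every switch of connection type would also require pruning the config.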
TruncatedConnection supports two modes, both via the
truncation_config key:
- On-the-fly: the truncation subgraph is built at runtime from the main graph using a coarser grid specification.
- File-based: precomputed .npz projection matrices are loaded from disk.
Choose one mode per config; do not mix the two within the same
truncation_config block.
On-the-fly example:
model:
  residual:
    _target_: anemoi.models.layers.residual.TruncatedConnection
    truncation_config:
      grid: o32
      num_nearest_neighbours: 3
      sigma: 1.0
File-based example:
model:
  residual:
    _target_: anemoi.models.layers.residual.TruncatedConnection
    truncation_config:
      truncation_down_file_path: /path/to/O96-O32-grid-box-average.mat.npz
      truncation_up_file_path: /path/to/O32-O96-grid-box-average.mat.npz
    row_normalize: false
Note
The top-level truncation_up_file_path and
truncation_down_file_path kwargs are still accepted for backward
compatibility, but the recommended approach is to move them inside truncation_config.
Learnable residual (Ornstein)
Learnable residual connections introduce a trainable scaling parameter \(\alpha\) on the residual connection, giving a formulation equivalent to a discretized Ornstein–Uhlenbeck process:
\[x_{t+1} = \alpha \, x_t + f_\theta(x_t)\]
With \(\alpha\) trainable and \(\alpha < 1\), errors in the state are contracted at each step rather than perfectly preserved. This bounds error growth during autoregressive integration.
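The contraction argument can be checked with a few lines of arithmetic. The sketch below assumes that each autoregressive step scales the existing error by \(\alpha\) and adds a fixed new error `step_error`; both that error model and the function name `accumulated_error` are illustrative assumptions, not part of the library.

```python
def accumulated_error(alpha, step_error, n_steps, initial_error=0.0):
    """Iterate e_{n+1} = alpha * e_n + step_error for n_steps steps."""
    error = initial_error
    for _ in range(n_steps):
        error = alpha * error + step_error
    return error


# alpha = 1: every step's error is kept, so the total grows linearly.
undamped = accumulated_error(alpha=1.0, step_error=0.1, n_steps=200)

# alpha < 1: older errors shrink geometrically, so the total approaches
# the geometric-series limit step_error / (1 - alpha).
damped = accumulated_error(alpha=0.9, step_error=0.1, n_steps=200)
bound = 0.1 / (1.0 - 0.9)
```

Under this error model the undamped total after 200 steps is about 200 times the per-step error, whereas the damped total settles near `bound`, however long the rollout runs.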
Two variants are available, offering increasing degrees of spatial structure in the learnable parameters.
Scalar Ornstein Connection
A single learnable scalar \(\alpha_v\) per prognostic variable \(v\):
\[x_{t+1}^{v} = \alpha_v \, x_t^{v} + f_\theta(x_t)^{v}\]
where \(\alpha_v \in (\alpha_{\text{buff}}, 1)\) is parameterized via a sigmoid. This is the simplest Ornstein variant: no spatial structure, just a per-variable damping factor.
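One way to keep a damping factor inside \((\alpha_{\text{buff}}, 1)\) is to store an unconstrained number and squash it with a sigmoid. The sketch below shows that idea in plain Python; the affine mapping, the function names `alpha_from_raw` and `scalar_ornstein_skip`, and the example variable names are assumptions for illustration and may differ from what the class does internally.

```python
import math


def alpha_from_raw(raw, alpha_buff=0.0):
    """Map an unconstrained value to the open interval (alpha_buff, 1)."""
    sigmoid = 1.0 / (1.0 + math.exp(-raw))
    return alpha_buff + (1.0 - alpha_buff) * sigmoid


def scalar_ornstein_skip(state, raw_params, alpha_buff=0.0):
    """Scale each variable of `state` by its own damping factor alpha_v."""
    return {
        name: alpha_from_raw(raw_params[name], alpha_buff) * value
        for name, value in state.items()
    }


# Hypothetical prognostic variables with one raw parameter each.
state = {"2t": 1.5, "msl": -0.8}
raw_params = {"2t": 3.0, "msl": 0.0}
skip = scalar_ornstein_skip(state, raw_params, alpha_buff=0.2)
```

Because the sigmoid never reaches 0 or 1 for the moderate raw values used here, each factor stays strictly inside the interval, and an optimizer can update the raw values without any explicit clipping.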
- class anemoi.models.layers.residual.ScalarOrnsteinConnection(theta_init: float = 0.0, theta_buff: float = 0.0, theta_train: bool = True, regressors: list[str] | None = None, graph: HeteroData | None = None, statistics: dict | None = None, data_indices=None, dataset_name: str | None = None, **_)
Bases: BaseResidualConnection
Ornstein residual with learnable scalars theta and mu.
residual(x) = (1 - theta) * x + mu + sum_i beta_i * f_i
where theta is in (theta_buff, 1) and learned independently for each prognostic variable. f_i are forcing variables listed in regressors. No spatial or spectral structure.
Parameters:
theta_init (float) – Initial value for theta. If 0 and statistics are available, auto-initialized from tendency statistics.
theta_buff (float) – Lower bound buffer for theta. Theta is constrained to (theta_buff, 1).
theta_train (bool) – Whether theta is a trainable parameter.
regressors (list[str] | None) – Variable names to use as regressors.
Spectral Ornstein Connection
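Spatially-varying \(\alpha\) and bias \(\mu\), defined as smooth functions on the sphere via spherical harmonic (SH) coefficients:
\[x_{t+1}^{v}(s) = \alpha_v(s) \, x_t^{v}(s) + \mu_v(s) + \sum_i \beta_{i,v}(s) \, f_i(s) + f_\theta(x_t)^{v}(s)\]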
where \(s\) denotes the spatial location, \(\alpha_v(s)\),
\(\mu_v(s)\), and \(\beta_{i,v}(s)\) are reconstructed from
low-order SH coefficients (controlled by lmax), and \(f_i\) are
optional forcing regressors.
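To give a feel for what a field reconstructed from low-order SH coefficients looks like, the sketch below evaluates an expansion up to degree 1 at a given latitude and longitude. It is a simplification: it uses four real orthonormal harmonics and direct evaluation, whereas the class stores complex coefficients and uses an inverse spherical harmonic transform. The function name `low_order_field` and the coefficient values are hypothetical.

```python
import math


def low_order_field(coeffs, lat_deg, lon_deg):
    """Evaluate a real spherical-harmonic expansion up to degree 1.

    coeffs = (c00, c1m1, c10, c11) multiply the orthonormal real
    harmonics Y_0^0, Y_1^-1, Y_1^0 and Y_1^1 respectively.
    """
    colat = math.radians(90.0 - lat_deg)
    lon = math.radians(lon_deg)
    y00 = 0.5 / math.sqrt(math.pi)
    k1 = math.sqrt(3.0 / (4.0 * math.pi))
    basis = (
        y00,
        k1 * math.sin(colat) * math.sin(lon),
        k1 * math.cos(colat),
        k1 * math.sin(colat) * math.cos(lon),
    )
    return sum(c * y for c, y in zip(coeffs, basis))


# Hypothetical coefficients: a global mean plus a north-south gradient.
alpha_coeffs = (2.5, 0.0, 0.4, 0.0)
alpha_north = low_order_field(alpha_coeffs, lat_deg=90.0, lon_deg=0.0)
alpha_equator = low_order_field(alpha_coeffs, lat_deg=0.0, lon_deg=0.0)
alpha_south = low_order_field(alpha_coeffs, lat_deg=-90.0, lon_deg=0.0)
```

With only a handful of coefficients the resulting field can vary smoothly between the poles but cannot represent small-scale structure, which is the sense in which a small lmax keeps the learned parameters smooth.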
When truncate=True, a learnable spectral low-pass filter is applied
to the input fields before computing the residual. This removes
high-frequency content from the skip connection, forcing the model to
reconstruct fine-scale detail from scratch. An optional anti-aliasing
blend (anti_aliasing=True) smoothly mixes the filtered and
unfiltered fields.
- class anemoi.models.layers.residual.SpectralOrnsteinConnection(lmax: int = 2, grid: str = 'regular', theta_init: float = 0.0, theta_buff: float = 0.0, use_mean: bool = True, regressors: list[str] | None = None, truncate: bool = False, skip_truncate_variables: list[str] | None = None, anti_aliasing: bool = True, graph: HeteroData | None = None, statistics: dict | None = None, data_indices=None, dataset_name: str | None = None, **_)
Bases: BaseResidualConnection
Ornstein residual with learnable spatially-varying theta and mu defined via spherical harmonics.
residual(x) = (1 - theta(s)) * x + mu(s) + sum_i beta_i(s) * f_i
where theta/mu/beta_i are stored as lmax x lmax complex SH coefficients (per prognostic variable), and the spatial fields are obtained via inverse SHT. f_i are forcing variables listed in regressors.
When truncate=True, a learnable spectral low-pass filter is applied to the input fields before computing the residual. This truncates high-frequency content from the skip connection.
Parameters:
lmax (int) – Maximum spherical harmonic degree for the theta/mu coefficients.
grid (str) – Grid type: "regular" for regular lat-lon, "octahedral" for octahedral reduced grids. Other types are not currently supported and will raise an error.
theta_init (float) – Initial value for theta.
theta_buff (float) – Lower bound buffer for theta.
use_mean (bool) – Whether to include the mean (mu) term.
regressors (list[str] | None) – Variable names to use as spatially-varying regressors.
truncate (bool) – If True, apply a learnable spectral low-pass filter to the input fields.
skip_truncate_variables (list[str] | None) – Variable names to exclude from spectral truncation (only used when truncate=True).
anti_aliasing (bool) – If True (and truncate=True), use anti-aliasing blending in the filter.