Layers

Environment Variables

ANEMOI_INFERENCE_NUM_CHUNKS

This environment variable controls the number of chunks used in the Mapper and Processor during inference. Setting it lets the model split large computations into the specified number of smaller chunks, reducing memory overhead. If it is not set, the default value of 1 applies, i.e. no chunking. See pull request #46. For finer control, set the following environment variables:

ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER

This environment variable controls the number of chunks used in the Mapper during inference.

ANEMOI_INFERENCE_NUM_CHUNKS_PROCESSOR

This environment variable controls the number of chunks used in the Processor during inference.
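The lookup described above can be sketched as follows. This is only an illustration: the helper resolve_num_chunks is hypothetical and not part of anemoi-models, and the precedence (component-specific variable first, then ANEMOI_INFERENCE_NUM_CHUNKS, then 1) is an assumption rather than documented behaviour. Only the variable names and the default of 1 come from the text.

```python
import os


def resolve_num_chunks(component: str, env=None) -> int:
    """Return the chunk count for component "MAPPER" or "PROCESSOR".

    Hypothetical helper; the lookup order is an assumption.
    """
    env = os.environ if env is None else env
    specific = env.get(f"ANEMOI_INFERENCE_NUM_CHUNKS_{component}")
    general = env.get("ANEMOI_INFERENCE_NUM_CHUNKS")
    # Assumed order: component-specific value, then general value, then 1.
    for value in (specific, general):
        if value is not None:
            return int(value)
    return 1  # default from the text: no chunking
```

Passing a plain dict as env makes the behaviour easy to check without touching the real process environment.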

Mappers

class anemoi.models.layers.mapper.BaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: DotDict, **kwargs)

Bases: Module, ABC

Base Mapper from source dimension to destination dimension.

Subclasses must implement pre_process() and post_process() methods specialized for their mapper type.
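The subclassing contract can be illustrated with a small stand-alone sketch. The classes below are hypothetical and do not reproduce BaseMapper: the method signatures, the list-based inputs, and the placeholder "mapping" step are all assumptions chosen to keep the example free of dependencies.

```python
from abc import ABC, abstractmethod


class SketchMapper(ABC):
    """Hypothetical stand-in for a mapper base class (not anemoi's BaseMapper)."""

    @abstractmethod
    def pre_process(self, x):
        """Prepare the (source, destination) pair before mapping."""

    @abstractmethod
    def post_process(self, x_dst):
        """Finalise the destination output after mapping."""

    def forward(self, x):
        x_src, x_dst = self.pre_process(x)
        # Placeholder mapping step: a real mapper would run its blocks here.
        mapped = [s + d for s, d in zip(x_src, x_dst)]
        return self.post_process(mapped)


class DoublingMapper(SketchMapper):
    """Concrete subclass that supplies both required methods."""

    def pre_process(self, x):
        x_src, x_dst = x
        return list(x_src), list(x_dst)

    def post_process(self, x_dst):
        return [2 * v for v in x_dst]
```

Because both methods are abstract, instantiating SketchMapper directly raises TypeError, while DoublingMapper can be used normally.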

abstractmethod forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tensor | Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor
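To make "a list of per-rank partition sizes" concrete, the sketch below builds such a list and derives the slice each rank would own. Both helpers are hypothetical, and the even split with the remainder going to the first ranks is an assumed convention, not necessarily the one used by BipartiteGraphShardInfo.

```python
def even_partition_sizes(num_items: int, world_size: int) -> list[int]:
    """Split num_items across world_size ranks as evenly as possible."""
    base, remainder = divmod(num_items, world_size)
    # Assumed convention: the first `remainder` ranks take one extra item.
    return [base + 1 if rank < remainder else base for rank in range(world_size)]


def local_slice(sizes: list[int], rank: int) -> slice:
    """Return the slice of the full (unsharded) dimension owned by `rank`."""
    start = sum(sizes[:rank])
    return slice(start, start + sizes[rank])
```

A replicated tensor would have no such list (None), since every rank holds the full dimension.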

class anemoi.models.layers.mapper.GraphTransformerBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: BaseMapper, ABC

Graph Transformer Base Mapper from hidden -> data or data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor
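For custom edges, the ordering that edges_are_dst_sorted refers to can be sketched as below. The helper is hypothetical and uses plain lists in place of tensors; it also assumes that the first row of edge_index holds source indices and the second holds destination indices.

```python
def sort_edges_by_destination(edge_index, edge_attr):
    """Reorder edges so destination indices are non-decreasing.

    edge_index is a pair of lists (sources, destinations) and edge_attr holds
    one entry per edge. Plain lists stand in for tensors in this sketch.
    """
    sources, destinations = edge_index
    # sorted() is stable, so edges sharing a destination keep their order.
    order = sorted(range(len(destinations)), key=destinations.__getitem__)
    sorted_index = ([sources[i] for i in order], [destinations[i] for i in order])
    sorted_attr = [edge_attr[i] for i in order]
    return sorted_index, sorted_attr
```

The same permutation is applied to edge_index and edge_attr so that each attribute stays attached to its edge.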

class anemoi.models.layers.mapper.GraphTransformerForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: GraphTransformerBaseMapper

Graph Transformer Mapper from data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default True.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.GraphTransformerBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', initialise_data_extractor_zero: bool = False, cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: GraphTransformerBaseMapper

Graph Transformer Mapper from hidden -> data.

class anemoi.models.layers.mapper.GNNBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, **kwargs)

Bases: BaseMapper, ABC

Base for Graph Neural Network Mapper from hidden -> data or data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.GNNForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)

Bases: GNNBaseMapper

Graph Neural Network Mapper from data -> hidden.

class anemoi.models.layers.mapper.GNNBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)

Bases: GNNBaseMapper

Graph Neural Network Mapper from hidden -> data.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tensor

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.PointWiseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None)

Bases: BaseMapper, ABC

PointWise Mapper from hidden -> data or data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.PointWiseForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None, **kwargs)

Bases: PointWiseMapper

PointWise Mapper from data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.PointWiseBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int, initialise_data_extractor_zero: bool = False, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None, **kwargs)

Bases: PointWiseMapper

PointWise Mapper from hidden -> data.

class anemoi.models.layers.mapper.TransformerBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, window_size: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False, cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)

Bases: BaseMapper, ABC

Transformer Base Mapper from hidden -> data or data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.TransformerForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm: bool = False, dropout_p: float = 0.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float = None, use_alibi_slopes: bool = False, cpu_offload: bool = False, window_size: int | None = None, use_rotary_embeddings: bool = False, layer_kernels: DotDict, **kwargs)

Bases: TransformerBaseMapper

Transformer Mapper from data -> hidden.

forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, **kwargs) Tuple[Tensor, Tensor]

Forward pass of the mapper.

Parameters:
  • x (PairTensor) – Input tensor pair (source, destination).

  • batch_size (int) – Batch size.

  • shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.

  • edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).

  • edge_index (Adj, optional) – Edge indices (required for graph-based mappers).

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • **kwargs (dict) – Additional keyword arguments passed to the mapper implementation.

Returns:

Mapper output tensor or tensor pair.

Return type:

Tensor or PairTensor

class anemoi.models.layers.mapper.TransformerBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm: bool = False, dropout_p: float = 0.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float = None, use_alibi_slopes: bool = False, cpu_offload: bool = False, window_size: int | None = None, use_rotary_embeddings: bool = False, layer_kernels: DotDict, **kwargs)

Bases: TransformerBaseMapper

Transformer Mapper from hidden -> data.

Processors

class anemoi.models.layers.processor.NoOpProcessor(**kwargs)

Bases: Module

No-op processor, used for ablations.

forward(x: Tensor, *args, **kwargs) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.processor.BaseProcessor(*, num_layers: int, num_channels: int, num_chunks: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: DotDict, **kwargs)

Bases: Module, ABC

Base Processor.

build_layers(layer_class, *layer_args, **layer_kwargs) None

Build Layers.

run_layers(data: tuple, *args, **kwargs) tuple

Run Layers with optional checkpoints around chunks.
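One way to picture "checkpoints around chunks" is the sketch below. It is not the BaseProcessor implementation: the function names, the near-equal grouping of layers, and the checkpoint(fn, *args) call shape are assumptions. In PyTorch, torch.utils.checkpoint.checkpoint is typically what fills the checkpoint role.

```python
def split_into_chunks(layers, num_chunks):
    """Group layers into num_chunks consecutive chunks of near-equal size."""
    base, remainder = divmod(len(layers), num_chunks)
    chunks, start = [], 0
    for i in range(num_chunks):
        size = base + (1 if i < remainder else 0)
        chunks.append(layers[start:start + size])
        start += size
    return chunks


def run_chunk(chunk, data):
    """Apply the layers of one chunk in order."""
    for layer in chunk:
        data = layer(data)
    return data


def run_layers(layers, data, num_chunks, checkpoint=None):
    """Apply every layer in order, one chunk at a time.

    `checkpoint` is an optional wrapper called as checkpoint(fn, *args).
    """
    for chunk in split_into_chunks(layers, num_chunks):
        if checkpoint is not None:
            data = checkpoint(run_chunk, chunk, data)
        else:
            data = run_chunk(chunk, data)
    return data
```

Wrapping whole chunks rather than single layers trades recomputation for memory: fewer chunks store fewer intermediate activations boundaries but recompute more per chunk during the backward pass.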

forward(x: Tensor, *args, **kwargs) Tensor

Example forward pass.

class anemoi.models.layers.processor.PointWiseMLPProcessor(*, num_layers: int, num_channels: int, num_chunks: int, mlp_hidden_ratio: float, cpu_offload: bool = False, dropout_p: float = 0.0, layer_kernels: DotDict, **kwargs)

Bases: BaseProcessor

Point-wise MLP Processor.

forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, model_comm_group: ProcessGroup | None = None, *args, **kwargs) Tensor

Example forward pass.

class anemoi.models.layers.processor.TransformerProcessor(*, num_layers: int, num_channels: int, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm=False, dropout_p: float = 0.0, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, window_size: int | None = None, cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)

Bases: BaseProcessor

Transformer Processor.

forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, *args, **kwargs) Tensor

Example forward pass.

class anemoi.models.layers.processor.GNNProcessor(*, num_channels: int, num_layers: int, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)

Bases: BaseProcessor

GNN Processor.

forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, edges_are_dst_sorted: bool = True, *args, **kwargs) Tensor

Run the GNN processor.

Parameters:
  • x (Tensor) – Node features.

  • batch_size (int) – Batch size.

  • shard_info (GraphShardInfo) – Shard metadata for node and edge tensors.

  • edge_attr (Tensor) – Edge attributes.

  • edge_index (Adj) – Edge indices.

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • *args (tuple) – Additional positional arguments.

  • **kwargs (dict) – Additional keyword arguments passed to processor blocks.

Returns:

Processed node features.

Return type:

Tensor

class anemoi.models.layers.processor.GraphTransformerProcessor(*, num_layers: int, num_channels: int, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: BaseProcessor

Graph Transformer Processor.

forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, edges_are_dst_sorted: bool = True, *args, **kwargs) Tensor

Run the graph-transformer processor.

Parameters:
  • x (Tensor) – Node features.

  • batch_size (int) – Batch size.

  • shard_info (GraphShardInfo) – Shard metadata for node and edge tensors.

  • edge_attr (Tensor) – Edge attributes.

  • edge_index (Adj) – Edge indices.

  • model_comm_group (ProcessGroup, optional) – Model communication group.

  • edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.

  • *args (tuple) – Additional positional arguments.

  • **kwargs (dict) – Additional keyword arguments passed to processor blocks.

Returns:

Processed node features.

Return type:

Tensor

Chunks

Blocks

class anemoi.models.layers.block.BaseBlock(**kwargs)

Bases: Module, ABC

Base class for network blocks.

abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo | BipartiteGraphShardInfo, batch_size: int, size: Tuple[int, int] | None = None, model_comm_group: ProcessGroup | None = None, **layer_kwargs) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.PointWiseMLPProcessorBlock(*, num_channels: int, hidden_dim: int, layer_kernels: DotDict, dropout_p: float = 0.0)

Bases: BaseBlock

Point-wise block with MLPs.

forward(x: Tensor, shard_info: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, **layer_kwargs) tuple[Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.TransformerProcessorBlock(*, num_channels: int, hidden_dim: int, num_heads: int, window_size: int | None, layer_kernels: DotDict, attn_channels: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)

Bases: BaseBlock

Transformer block with MultiHeadSelfAttention and MLPs.

forward(x: Tensor, shard_info: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, cond: Tensor | None = None, **layer_kwargs) tuple[Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.TransformerMapperBlock(*, num_channels: int, hidden_dim: int, num_heads: int, window_size: int | None, layer_kernels: DotDict, attn_channels: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)

Bases: TransformerProcessorBlock

Transformer mapper block with MultiHeadCrossAttention and MLPs.

forward(x: Tuple[Tensor, Tensor | None], shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, cond: tuple[Tensor, Tensor] | None = None) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphConvBaseBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)

Bases: BaseBlock

Message passing block with MLPs for node embeddings.

abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo | BipartiteGraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphConvProcessorBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)

Bases: GraphConvBaseBlock

forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphConvMapperBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)

Bases: GraphConvBaseBlock

forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) tuple[Tensor, Tensor]

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphTransformerBaseBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, attn_channels: int | None = None, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: BaseBlock, ABC

Message passing block with MLPs for node embeddings.

shard_qkve_heads(query: Tensor, key: Tensor, value: Tensor, edges: Tensor, shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) tuple[Tensor, Tensor, Tensor, Tensor, list[int] | None]

Shards qkv and edges along head dimension using all_to_all_transpose.

shard_output_seq(out: Tensor, shard_info: BipartiteGraphShardInfo, head_shard_sizes: list[int] | None, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor

Shards Tensor sequence dimension using all_to_all_transpose.

abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphTransformerMapperBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: GraphTransformerBaseBlock

Graph Transformer Block for node embeddings.

forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, cond: tuple[Tensor, Tensor] | None = None, edges_are_dst_sorted: bool = True, **layer_kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.block.GraphTransformerProcessorBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)

Bases: GraphTransformerBaseBlock

Graph Transformer Block for node embeddings.

forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, cond: Tensor | None = None, edges_are_dst_sorted: bool = True, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Graph

class anemoi.models.layers.graph.TrainableTensor(tensor_size: int, trainable_size: int)

Bases: Module

Trainable Tensor Module.

forward(x: Tensor, batch_size: int) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.graph.NamedNodesAttributes(trainable_parameters: dict[str, int], graph_data: HeteroData)

Bases: Module

Named Nodes Attributes information.

num_nodes

Number of nodes for each group of nodes.

Type:

dict[str, int]

attr_ndims

Total dimension of node attributes (non-trainable + trainable) for each group of nodes.

Type:

dict[str, int]

trainable_tensors

Dictionary of trainable tensors for each group of nodes.

Type:

nn.ModuleDict

get_coordinates(self, name: str) Tensor

Get the coordinates of a set of nodes.

forward(self, name: str, batch_size: int) Tensor

Get the node attributes to be passed through the graph neural network.

define_fixed_attributes(graph_data: HeteroData, trainable_parameters: dict[str, int]) None

Define fixed attributes.

register_coordinates(name: str, node_coords: Tensor) None

Register coordinates.

get_coordinates(name: str) Tensor

Return original coordinates.

register_tensor(name: str, num_trainable_params: int) None

Register a trainable tensor.

forward(name: str, batch_size: int) Tensor

Returns the node attributes to be passed through the graph neural network.

It includes both the coordinates and the trainable parameters.
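The combination of coordinates and trainable parameters described above can be sketched in plain Python. This is a minimal illustration, assuming the fixed and trainable parts are joined along the feature axis and repeated once per batch element; the function name node_attributes and the list-based layout are hypothetical and are not taken from the library.

```python
def node_attributes(coords, trainable, batch_size):
    """Concatenate fixed and trainable features per node, repeated per batch.

    coords:    list of per-node fixed feature lists (e.g. coordinates)
    trainable: list of per-node trainable feature lists
    """
    if len(coords) != len(trainable):
        raise ValueError("coords and trainable must describe the same nodes")
    # Feature-axis concatenation: width = fixed width + trainable width
    per_node = [list(c) + list(t) for c, t in zip(coords, trainable)]
    # Repeat the node block batch_size times along the node axis
    return [list(row) for _ in range(batch_size) for row in per_node]


coords = [[0.0, 1.0], [0.5, -1.0]]  # 2 nodes, 2 fixed features each
trainable = [[0.1], [0.2]]          # 1 trainable feature per node
out = node_attributes(coords, trainable, batch_size=3)
print(len(out), len(out[0]))
```

In this sketch the width of each row plays the role of attr_ndims (non-trainable plus trainable dimensions) for one group of nodes.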

Graph Providers

anemoi.models.layers.graph_provider.create_graph_provider(graph: HeteroData | None = None, edge_attributes: list[str] | None = None, src_size: int | None = None, dst_size: int | None = None, trainable_size: int = 0) BaseGraphProvider

Factory function to create appropriate graph provider.

Returns StaticGraphProvider if graph has edges, otherwise returns NoOpGraphProvider for edge-less architectures.

Parameters:
  • graph (HeteroData, optional) – Graph containing edges (for static mode)

  • edge_attributes (list[str], optional) – Edge attributes to use (for static mode)

  • src_size (int, optional) – Source grid size (for static mode)

  • dst_size (int, optional) – Destination grid size (for static mode)

  • trainable_size (int, optional) – Trainable tensor size, by default 0

Returns:

Appropriate graph provider instance

Return type:

BaseGraphProvider

class anemoi.models.layers.graph_provider.BaseGraphProvider(*args: Any, **kwargs: Any)

Bases: Module, ABC

Base class for graph edge providers.

Graph providers encapsulate the logic for supplying edge indices and attributes to mapper and processor layers. This allows for different strategies (static, dynamic, etc.).

abstractmethod get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None] | Tensor

Get edge information.

Parameters:
  • batch_size (int, optional) – Number of times to expand the edge index (used by static mode)

  • src_coords (Tensor, optional) – Source node coordinates (used by dynamic mode for k-NN, radius graphs, etc.)

  • dst_coords (Tensor, optional) – Destination node coordinates (used by dynamic mode for k-NN, radius graphs, etc.)

  • model_comm_group (ProcessGroup, optional) – Model communication group

  • shard_edges (bool, optional) – Whether to shard edges, by default True

Returns:

  • For standard providers: (edge_attr, edge_index, edge_shard_sizes) tuple

  • For sparse providers: sparse projection matrix

Return type:

Union[tuple[Tensor, Adj, Optional[ShardSizes]], Tensor]

abstract property edge_dim: int

Return the edge dimension.

property is_sparse: bool

Whether this provider returns sparse matrices.

class anemoi.models.layers.graph_provider.StaticGraphProvider(graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int)

Bases: BaseGraphProvider

Provider for static graphs with fixed edge structure.

This provider owns all graph-related state including edge attributes, edge indices, and trainable parameters.

property edge_dim: int

Return the edge dimension.

get_edges(batch_size: int, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, act_checkpoint: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None]

Get edge attributes and expanded edge index for static graph.

Parameters:
  • batch_size (int) – Number of times to expand the edge index

  • src_coords (Tensor, optional) – Source node coordinates (ignored for static graphs)

  • dst_coords (Tensor, optional) – Destination node coordinates (ignored for static graphs)

  • model_comm_group (ProcessGroup, optional) – Model communication group

  • shard_edges (bool, optional) – Whether to shard edges, by default True.

  • act_checkpoint (bool, optional) – Whether to use gradient checkpointing, by default True.

Returns:

Edge attributes, expanded edge index, and optional edge_shard_sizes. edge_shard_sizes is a list of per-rank partition sizes when shard_edges=True, otherwise None.
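One way to picture "expanding the edge index batch_size times" is to offset the node indices of each copy by the grid sizes. The sketch below is written under that assumption and also assumes an even split of edges across ranks for edge_shard_sizes. Both helper names are hypothetical, and the library may expand or partition differently.

```python
def expand_edge_index(edge_index, batch_size, src_size, dst_size):
    """Repeat a [sources, destinations] edge list once per batch element."""
    src, dst = edge_index
    exp_src, exp_dst = [], []
    for b in range(batch_size):
        # Shift indices so each batch element addresses its own node block
        exp_src.extend(s + b * src_size for s in src)
        exp_dst.extend(d + b * dst_size for d in dst)
    return [exp_src, exp_dst]


def even_shard_sizes(num_items, num_ranks):
    """Per-rank partition sizes for an (assumed) near-even split."""
    base, extra = divmod(num_items, num_ranks)
    return [base + (1 if rank < extra else 0) for rank in range(num_ranks)]


edge_index = [[0, 1, 2], [0, 0, 1]]  # 3 edges from 3 source to 2 destination nodes
expanded = expand_edge_index(edge_index, batch_size=2, src_size=3, dst_size=2)
sizes = even_shard_sizes(len(expanded[0]), num_ranks=4)
print(expanded, sizes)
```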

Return type:

tuple[Tensor, Adj, Optional[ShardSizes]]

class anemoi.models.layers.graph_provider.NoOpGraphProvider

Bases: BaseGraphProvider

Provider for edge-less architectures (e.g., Transformers).

Returns None for edges and has edge_dim=0. Used when the mapper/processor does not require graph structure (e.g., pure attention-based models).

property edge_dim: int

Return the edge dimension (0 for no edges).

get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True) tuple[None, None, None]

Return None for edge attributes, edge index, and edge_shard_sizes.

Parameters:
  • batch_size (int, optional) – Unused

  • src_coords (Tensor, optional) – Unused

  • dst_coords (Tensor, optional) – Unused

  • model_comm_group (ProcessGroup, optional) – Unused

  • shard_edges (bool, optional) – Unused

Returns:

No edges

Return type:

tuple[None, None, None]
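The provider interface and the factory rule described above (a static provider when the graph has edges, a no-op provider otherwise) can be sketched with standard-library classes. Everything here is a stand-in: the class names, the dict used in place of HeteroData, and the simplified get_edges signature are assumptions made for illustration.

```python
from abc import ABC, abstractmethod


class ProviderSketch(ABC):
    """Stand-in for a graph provider base class."""

    @property
    @abstractmethod
    def edge_dim(self):
        """Width of the edge attributes."""

    @property
    def is_sparse(self):
        # Default: the provider returns an (attr, index, shard_sizes) tuple
        return False

    @abstractmethod
    def get_edges(self, batch_size=None):
        """Return edge information."""


class NoOpSketch(ProviderSketch):
    @property
    def edge_dim(self):
        return 0

    def get_edges(self, batch_size=None):
        return None, None, None


class StaticSketch(ProviderSketch):
    def __init__(self, edge_attr, edge_index):
        self._edge_attr = edge_attr
        self._edge_index = edge_index

    @property
    def edge_dim(self):
        return len(self._edge_attr[0])

    def get_edges(self, batch_size=None):
        return self._edge_attr, self._edge_index, None


def create_provider_sketch(graph=None):
    """Pick a provider based on whether the (dict) graph carries edges."""
    if not graph or not graph.get("edge_index"):
        return NoOpSketch()
    return StaticSketch(graph["edge_attr"], graph["edge_index"])


no_edges = create_provider_sketch(None)
static = create_provider_sketch(
    {"edge_attr": [[0.1, 0.2], [0.3, 0.4]], "edge_index": [[0, 1], [1, 0]]}
)
print(no_edges.edge_dim, static.edge_dim)
```

Callers that must handle both tuple-returning and sparse providers can branch on is_sparse before unpacking the result of get_edges.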

class anemoi.models.layers.graph_provider.DynamicGraphProvider(edge_dim: int)

Bases: BaseGraphProvider

Provider for dynamic graphs where edges are supplied at runtime.

Does not support trainable edge parameters.

Future implementation will support on-the-fly graph construction via build_graph() (e.g., k-NN graphs, radius graphs, adaptive connectivity).

property edge_dim: int

Return the edge dimension.

build_graph(src_nodes: Tensor, dst_nodes: Tensor, **kwargs) tuple[Tensor, Tensor | SparseTensor]

Build graph dynamically from source and destination nodes.

This method will be implemented in the future to support on-the-fly graph construction (e.g., k-NN graphs, radius graphs, etc.).

Parameters:
  • src_nodes (Tensor) – Source node features/positions

  • dst_nodes (Tensor) – Destination node features/positions

  • **kwargs – Additional parameters for graph construction algorithm

Returns:

Edge attributes and edge index

Return type:

tuple[Tensor, Adj]

Raises:

NotImplementedError – This functionality is not yet implemented

get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, act_checkpoint: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None]

Get dynamic edges constructed from node coordinates.

Calls build_graph() to construct edges on-the-fly using k-NN, radius graphs, etc.

Parameters:
  • batch_size (int, optional) – Batch size (currently unused, reserved for future implementation)

  • src_coords (Tensor, optional) – Source node coordinates

  • dst_coords (Tensor, optional) – Destination node coordinates

  • model_comm_group (ProcessGroup, optional) – Model communication group

  • shard_edges (bool, optional) – Whether to shard edges, by default True

  • act_checkpoint (bool, optional) – Whether to use gradient checkpointing, by default True.

Returns:

Edge attributes, edge index, and optional edge_shard_sizes.

Return type:

tuple[Tensor, Adj, Optional[ShardSizes]]

Raises:
class anemoi.models.layers.graph_provider.ProjectionGraphProvider(graph: HeteroData | None = None, edges_name: tuple[str, str, str] | None = None, edge_weight_attribute: str | None = None, src_node_weight_attribute: str | None = None, file_path: str | Path | None = None, row_normalize: bool = False)

Bases: BaseGraphProvider

Provider for sparse projection matrices.

Builds and stores sparse projection matrix from graph or file.

property edge_dim: int

Return projection matrix shape.

property is_sparse: bool

This provider returns sparse matrices.

get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, device: device | None = None) Tensor

Return the sparse projection matrix.

Parameters:
  • batch_size (int, optional) – Unused for sparse providers

  • src_coords (Tensor, optional) – Unused for sparse providers

  • dst_coords (Tensor, optional) – Unused for sparse providers

  • model_comm_group (ProcessGroup, optional) – Unused for sparse providers

  • shard_edges (bool, optional) – Unused for sparse providers

  • device (torch.device, optional) – Target device for matrix

Returns:

Sparse projection matrix

Return type:

Tensor

Conv

class anemoi.models.layers.conv.GraphConv(in_channels: int, out_channels: int, layer_kernels: DotDict, mlp_extra_layers: int = 0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', **kwargs)

Bases: MessagePassing

Message passing module for convolutional node and edge interactions.

forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, size: Tuple[int, int] | None = None)

Runs the forward pass of the module.

message(x_i: Tensor, x_j: Tensor, edge_attr: Tensor, dim_size: int | None = None) Tensor

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.

aggregate(edges_new: Tensor, edge_index: Tensor | SparseTensor, dim_size: int | None = None) tuple[Tensor, Tensor]

Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).

Takes in the output of message computation as first argument and any argument which was initially passed to propagate().

By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.
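The message and aggregate steps can be illustrated with scalar node features in plain Python. This sketch assumes the usual source-to-target convention, where edge_index[0] holds the source nodes \(j\) and edge_index[1] the target nodes \(i\), and it uses a sum as the aggregation \(\bigoplus\). The message function here (a plain sum of its inputs) is a hypothetical placeholder for the learned MLP.

```python
def propagate_sum(x_src, x_dst, edge_index, edge_attr):
    """Compute one message per edge and sum the messages at each target node."""
    src, dst = edge_index
    out = [0.0] * len(x_dst)
    for e, (j, i) in enumerate(zip(src, dst)):
        # message: combines target x_i, source x_j and the edge attribute
        message = x_dst[i] + x_src[j] + edge_attr[e]
        # aggregate: sum over all neighbours j of target node i
        out[i] += message
    return out


x_src = [1.0, 2.0, 3.0]
x_dst = [10.0, 20.0]
edge_index = [[0, 1, 2], [0, 0, 1]]  # edges 0->0, 1->0, 2->1
edge_attr = [0.5, 0.5, 0.5]
aggregated = propagate_sum(x_src, x_dst, edge_index, edge_attr)
print(aggregated)
```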

class anemoi.models.layers.conv.GraphTransformerConv(out_channels: int, dropout: float = 0.0, **kwargs)

Bases: MessagePassing

Message passing part of graph transformer operator.

Adapted from ‘Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification’ (https://arxiv.org/abs/2009.03509)

forward(query: Tensor, key: Tensor, value: Tensor, edge_attr: Tensor | None, edge_index: Tensor | SparseTensor, size: Tuple[int, int] | None = None)

Runs the forward pass of the module.

message(heads: int, query_i: Tensor, key_j: Tensor, value_j: Tensor, edge_attr: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) Tensor

Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
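For a graph transformer, the message weights are attention scores normalised over the incoming edges of each target node. The following single-head sketch shows that normalisation with plain lists. It omits edge attributes, dropout and multiple heads, and the function name is hypothetical, so it should be read as an illustration of the idea rather than the layer's actual computation.

```python
import math


def edge_softmax_attention(query, key, value, edge_index):
    """Single-head attention where each target node attends over its in-edges."""
    src, dst = edge_index
    dim = len(query[0])
    # Scaled dot-product score for every edge j -> i
    scores = [
        sum(q * k for q, k in zip(query[i], key[j])) / math.sqrt(dim)
        for j, i in zip(src, dst)
    ]
    out = [[0.0] * len(value[0]) for _ in query]
    for node in range(len(query)):
        incoming = [e for e, i in enumerate(dst) if i == node]
        if not incoming:
            continue  # nodes without in-edges keep a zero output
        peak = max(scores[e] for e in incoming)  # subtract max for stability
        weights = [math.exp(scores[e] - peak) for e in incoming]
        total = sum(weights)
        for e, w in zip(incoming, weights):
            for c, v in enumerate(value[src[e]]):
                out[node][c] += (w / total) * v
    return out


query = [[1.0, 0.0]]             # one target node
key = [[0.0, 0.0], [0.0, 0.0]]   # two source nodes with identical scores
value = [[2.0], [4.0]]
attended = edge_softmax_attention(query, key, value, [[0, 1], [0, 0]])
print(attended)
```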

Attention

class anemoi.models.layers.attention.MultiHeadSelfAttention(num_heads: int, embed_dim: int, layer_kernels: DotDict, attn_channels: int | None = None, qkv_bias: bool = False, qk_norm: bool = False, is_causal: bool = False, window_size: int | None = None, dropout_p: float = 0.0, attention_implementation: str = 'flash_attention', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)

Bases: Module

Multi-head self-attention PyTorch layer.

Supports the following attention implementations:

  • scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html

  • flash attention, see https://github.com/Dao-AILab/flash-attention

The config parameter “model.processor.attention_implementation” is used to control which attention implementation is used.

“scaled_dot_product_attention” (SDPA)

SDPA is a PyTorch function, so it is the easiest to use but the least performant. It runs on CPUs and GPUs.

“flash_attention”

Flash attention is optimised for efficient use of the GPU's memory hierarchy. It loads smaller chunks into fast local memory and fuses attention into a single kernel to reduce the number of passes through memory. It runs on Nvidia Ampere (e.g. A100) GPUs or newer and AMD MI200 GPUs or newer. Check the GitHub repository for the full requirements. You have to install flash attention yourself. If you are running on an x86 system, prebuilt wheels are available in the GitHub repository. On an aarch64 system, you have to build flash attention from source.

forward(x: Tensor, grid_shard_sizes: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.attention.SDPAAttentionWrapper

Bases: Module

Wrapper for PyTorch scaled dot product attention.

To use this attention implementation, set model.processor.attention_implementation='scaled_dot_product_attention'.

create_sliding_window_mask(B, H, Q_LEN, KV_LEN, window_size, device='cpu') Tensor

Create a mask for sliding window attention compatible with SDPA.

Parameters:
  • B (int) – Batch size

  • H (int) – Number of heads

  • Q_LEN (int) – Query sequence length

  • KV_LEN (int) – Key/value sequence length

  • window_size (tuple) – Tuple of (left_window, right_window). Use -1 for unlimited.

  • device (str) – Device for the mask tensor

Returns:

2D attention mask

Return type:

Tensor
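A sliding-window mask of this kind can be sketched as a nested list of booleans. The sketch follows the boolean-mask convention of PyTorch SDPA, where True marks a key position that may be attended to, and treats -1 as an unlimited window on that side. It assumes query and key positions are aligned at index 0 and ignores the batch and head arguments. The function name is hypothetical.

```python
def sliding_window_mask(q_len, kv_len, window_size):
    """Return a q_len x kv_len boolean mask for (left, right) window sizes."""
    left, right = window_size
    mask = []
    for q in range(q_len):
        row = []
        for k in range(kv_len):
            # Keys at most `left` positions behind the query are visible
            within_left = left < 0 or q - k <= left
            # Keys at most `right` positions ahead of the query are visible
            within_right = right < 0 or k - q <= right
            row.append(within_left and within_right)
        mask.append(row)
    return mask


mask = sliding_window_mask(4, 4, (1, 0))
for row in mask:
    print(["x" if allowed else "." for allowed in row])
```

With window_size=(1, 0), each query sees itself and one position to its left, which is a causal window of width two.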

forward(query, key, value, batch_size: int, causal=False, window_size=None, dropout_p=0.0, softcap=None, alibi_slopes=None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.attention.FlashAttentionWrapper(use_rotary_embeddings: bool = False, head_dim: int = None)

Bases: Module

Wrapper for Flash attention.

Uses either flash attention v2 or flash attention v3 (optimised for Hopper GPUs and newer), depending on which is installed. Flash attention v3 does not support rotary embeddings or ALiBi slopes. To use these features, downgrade to flash attention v2.

forward(query, key, value, batch_size: int, causal: bool = False, window_size: int | None = None, dropout_p: float = 0.0, softcap: float | None = None, alibi_slopes: Tensor = None)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class anemoi.models.layers.attention.MultiHeadCrossAttention(*args, **kwargs)

Bases: MultiHeadSelfAttention

Multi-head cross-attention PyTorch layer.

forward(x: Tuple[Tensor, Tensor], shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

anemoi.models.layers.attention.get_alibi_slopes(num_heads: int) Tensor

Calculates linearly decreasing slopes for ALiBi attention.

Parameters:

num_heads (int) – number of attention heads

Returns:

ALiBi slopes

Return type:

Tensor

Multi-Layer Perceptron

class anemoi.models.layers.mlp.GatedMLPLayer(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'])

Bases: Module

Single gated feed-forward layer used by GLU variants.

forward(x: Tensor) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

anemoi.models.layers.mlp.build_feedforward_layer(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp') Module

Build one feed-forward layer module.

anemoi.models.layers.mlp.build_feedforward_modules(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp') list[Module]

Build one feed-forward layer as a flat list of modules.
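The gated options listed in mlp_implementation follow the GLU family, in which one branch is passed through an activation and multiplied elementwise with a second branch. The sketch below shows only that gating step on plain lists. In a real layer both branches would come from linear projections of the input, and which branch the library gates is an assumption here, as are the names used.

```python
import math


def _sigmoid(v):
    return 1.0 / (1.0 + math.exp(-v))


# Activation applied to the gate branch for each GLU variant
GATE_ACTIVATIONS = {
    "glu": _sigmoid,
    "swiglu": lambda v: v * _sigmoid(v),  # SiLU / Swish
    "geglu": lambda v: 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0))),  # GELU
    "reglu": lambda v: max(0.0, v),  # ReLU
}


def gated_unit(gate, value, kind):
    """Elementwise product of the activated gate branch and the value branch."""
    activation = GATE_ACTIVATIONS[kind]
    return [activation(g) * v for g, v in zip(gate, value)]


print(gated_unit([-1.0, 2.0], [3.0, 4.0], "reglu"))
print(gated_unit([0.0], [4.0], "glu"))
```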

class anemoi.models.layers.mlp.MLP(in_features: int, hidden_dim: int, out_features: int, layer_kernels: DotDict, n_extra_layers: int = 0, final_activation: bool = False, layer_norm: bool = True, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp')

Bases: Module

Multi-layer perceptron with optional checkpoint.

forward(x: Tensor, **layer_kwargs) Tensor

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

Utils

anemoi.models.layers.utils.compute_mlp_hidden_dim(num_channels: int, mlp_hidden_ratio: float) int

Compute integer hidden dimension from a (possibly fractional) MLP ratio.

Parameters:
  • num_channels (int) – Base channel width.

  • mlp_hidden_ratio (float) – Multiplier used to derive hidden width.

Returns:

Rounded hidden width.

Return type:

int
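As a rough illustration of deriving an integer width from a fractional ratio, the sketch below multiplies and rounds to the nearest integer, with a floor of one. The exact rounding rule and any lower bound used by the library are not stated above, so both are assumptions, and hidden_dim_sketch is a hypothetical name.

```python
def hidden_dim_sketch(num_channels, mlp_hidden_ratio):
    """Round num_channels * ratio to an integer width of at least 1."""
    return max(1, int(round(num_channels * mlp_hidden_ratio)))


wide = hidden_dim_sketch(128, 4.0)
fractional = hidden_dim_sketch(96, 2.667)
print(wide, fractional)
```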

class anemoi.models.layers.utils.CheckpointWrapper(module: Module)

Bases: Module

Wrapper for checkpointing a module.

forward(*args, **kwargs)

Define the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

anemoi.models.layers.utils.maybe_checkpoint(func, enabled: bool, *args, **kwargs)

Conditionally apply gradient checkpointing to a function.

Parameters:
  • func (callable) – The function to potentially wrap with checkpointing

  • enabled (bool) – Whether to apply gradient checkpointing

  • *args – Positional arguments to pass to the function

  • **kwargs – Keyword arguments to pass to the function

Returns:

The result of calling func with the provided arguments
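The conditional dispatch can be sketched as follows. To keep the example runnable without PyTorch, the checkpointing routine is injected as a parameter and replaced by a recording stand-in. In a PyTorch setting that role would typically be played by torch.utils.checkpoint.checkpoint. The extra checkpoint_fn parameter and all names are hypothetical and are not part of the signature documented above.

```python
def maybe_checkpoint_sketch(func, enabled, *args, checkpoint_fn, **kwargs):
    """Route the call through checkpoint_fn only when enabled is True."""
    if enabled:
        return checkpoint_fn(func, *args, **kwargs)
    return func(*args, **kwargs)


calls = []


def recording_checkpoint(func, *args, **kwargs):
    """Stand-in that records the call and then runs func directly.

    A real checkpoint routine would discard intermediate activations and
    recompute them during the backward pass.
    """
    calls.append(func.__name__)
    return func(*args, **kwargs)


def scale(x, factor=1.0):
    return x * factor


with_ckpt = maybe_checkpoint_sketch(
    scale, True, 2.0, checkpoint_fn=recording_checkpoint, factor=3.0
)
without_ckpt = maybe_checkpoint_sketch(
    scale, False, 2.0, checkpoint_fn=recording_checkpoint, factor=3.0
)
print(with_ckpt, without_ckpt, calls)
```

Both calls return the same value; only the enabled call passes through the checkpoint routine.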

anemoi.models.layers.utils.load_layer_kernels(kernel_config: DotDict | None = None, instance: bool = True) DotDict

Load layer kernels from the config.

This function tries to load the layer kernels from the config. If the layer kernel is not supplied, it will fall back to the torch.nn implementation.

Parameters:
  • kernel_config (DotDict) – Kernel configuration, e.g. {"Linear": {"_target_": "torch.nn.Linear"}}

  • instance (bool) – If True, instantiate the kernels. If False, return the config. This is useful for testing purposes. Defaults to True.

Returns:

Container with layer factories.

Return type:

DotDict
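The fallback behaviour described above, where a configured _target_ is used if present and a default otherwise, can be sketched with the standard library. Standard-library classes stand in for torch.nn layers so that the example runs without PyTorch. The default table, the helper names and the plain dict return value are assumptions made for illustration.

```python
import importlib

# Hypothetical defaults; the library falls back to torch.nn implementations
DEFAULT_TARGETS = {"Mapping": "builtins.dict", "Sequence": "builtins.list"}


def resolve_target(dotted_path):
    """Import 'package.module.Name' and return the named attribute."""
    module_name, _, attribute = dotted_path.rpartition(".")
    return getattr(importlib.import_module(module_name), attribute)


def load_kernels_sketch(kernel_config=None, defaults=DEFAULT_TARGETS):
    """Resolve each kernel from the config, falling back to the default."""
    kernel_config = kernel_config or {}
    kernels = {}
    for name, default_target in defaults.items():
        target = kernel_config.get(name, {}).get("_target_", default_target)
        kernels[name] = resolve_target(target)
    return kernels


kernels = load_kernels_sketch({"Mapping": {"_target_": "collections.OrderedDict"}})
print(kernels["Mapping"].__name__, kernels["Sequence"].__name__)
```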