Layers
Environment Variables
ANEMOI_INFERENCE_NUM_CHUNKS
This environment variable controls the number of chunks used in the Mapper and Processor during inference. Setting it splits large computations into the specified number of smaller chunks, reducing memory overhead. If it is not set, the default value of 1 is used, i.e. no chunking. See pull request #46. For finer control, set the following environment variables:
ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER
This environment variable controls the number of chunks used in the Mapper during inference.
ANEMOI_INFERENCE_NUM_CHUNKS_PROCESSOR
This environment variable controls the number of chunks used in the Processor during inference.
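The lookup and fallback described above can be sketched in plain Python. This is a hypothetical illustration, not anemoi code: `resolve_num_chunks` and `chunked_apply` are made-up names, and the precedence of a component-specific variable over ANEMOI_INFERENCE_NUM_CHUNKS is an assumption drawn from the "finer control" wording. The chunking helper only shows why splitting bounds peak memory without changing the result.

```python
import os
from typing import Callable, List, Mapping, Sequence


def resolve_num_chunks(component: str, environ: Mapping[str, str] = os.environ) -> int:
    """Hypothetical helper: chunk count for "mapper" or "processor".

    Assumption: ANEMOI_INFERENCE_NUM_CHUNKS_<COMPONENT> overrides the
    general ANEMOI_INFERENCE_NUM_CHUNKS; with neither set, 1 (no chunking).
    """
    specific = environ.get(f"ANEMOI_INFERENCE_NUM_CHUNKS_{component.upper()}")
    if specific is not None:
        return int(specific)
    return int(environ.get("ANEMOI_INFERENCE_NUM_CHUNKS", "1"))


def chunked_apply(
    fn: Callable[[Sequence[int]], List[int]], data: Sequence[int], num_chunks: int
) -> List[int]:
    """Apply fn to num_chunks contiguous slices and concatenate the results.

    Each call only materialises intermediates for one slice, which is what
    bounds peak memory; for an element-wise fn the output equals fn(data).
    """
    base, extra = divmod(len(data), num_chunks)
    out: List[int] = []
    start = 0
    for i in range(num_chunks):
        # the first `extra` slices take one extra element
        stop = start + base + (1 if i < extra else 0)
        out.extend(fn(data[start:stop]))
        start = stop
    return out
```

Under that assumption, setting ANEMOI_INFERENCE_NUM_CHUNKS=4 and ANEMOI_INFERENCE_NUM_CHUNKS_MAPPER=8 would give the Mapper 8 chunks and the Processor 4.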
Mappers
- class anemoi.models.layers.mapper.BaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: DotDict, **kwargs)
Bases: Module, ABC
Base Mapper from source dimension to destination dimension.
Subclasses must implement pre_process() and post_process() methods specialized for their mapper type.
- abstractmethod forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tensor | Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
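As a minimal sketch of how per-rank partition sizes map to index ranges, the hypothetical helper below works on a single shard-info field (the field names of BipartiteGraphShardInfo are not listed here). Partitions may be uneven; a None field means every rank holds the full dimension.

```python
from itertools import accumulate
from typing import Optional, Sequence


def local_range(sizes: Optional[Sequence[int]], rank: int, full_length: int) -> range:
    """Indices along the sharded dimension held by `rank` (hypothetical helper).

    `sizes` stands for one shard-info field: per-rank partition sizes, or
    None when the tensor is replicated and every rank holds everything.
    """
    if sizes is None:
        return range(full_length)
    if sum(sizes) != full_length:
        raise ValueError("partition sizes must add up to the full length")
    offsets = [0, *accumulate(sizes)]  # e.g. [4, 3, 3] -> [0, 4, 7, 10]
    return range(offsets[rank], offsets[rank + 1])
```

For sizes [4, 3, 3], rank 1 would own indices 4 to 6, i.e. range(4, 7).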
- class anemoi.models.layers.mapper.GraphTransformerBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases: BaseMapper, ABC
Graph Transformer Base Mapper from hidden -> data or data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
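For custom full-graph edges, a stable sort on the destination row is one way to obtain destination-ordered edges before the call. The helper below is a hypothetical pure-Python sketch that assumes the PyTorch Geometric convention (row 0 holds sources, row 1 destinations); it applies the same permutation to edge_attr so each attribute row stays with its edge.

```python
from typing import List, Sequence, Tuple


def sort_edges_by_dst(
    edge_index: Sequence[Sequence[int]], edge_attr: Sequence[Sequence[float]]
) -> Tuple[List[List[int]], List[Sequence[float]]]:
    """Reorder edges so destination indices are non-decreasing (hypothetical helper)."""
    src, dst = edge_index
    # sorted() is stable, so edges sharing a destination keep their relative order
    perm = sorted(range(len(dst)), key=lambda e: dst[e])
    sorted_index = [[src[e] for e in perm], [dst[e] for e in perm]]
    sorted_attr = [edge_attr[e] for e in perm]
    return sorted_index, sorted_attr
```

With tensors, the equivalent permutation would be `torch.argsort(edge_index[1], stable=True)`, applied as `edge_index[:, perm]` and `edge_attr[perm]`.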
- class anemoi.models.layers.mapper.GraphTransformerForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases: GraphTransformerBaseMapper
Graph Transformer Mapper from data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default True.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
- class anemoi.models.layers.mapper.GraphTransformerBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', initialise_data_extractor_zero: bool = False, cpu_offload: bool = False, layer_kernels: DotDict = None, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases: GraphTransformerBaseMapper
Graph Transformer Mapper from hidden -> data.
- class anemoi.models.layers.mapper.GNNBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict = None, **kwargs)
Bases: BaseMapper, ABC
Base for Graph Neural Network Mapper from hidden -> data or data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
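When edges are sharded up front, destination-sorted edges make each rank's share a contiguous slice. The hypothetical helper below sketches that split from per-rank destination-node partition sizes; it is an illustration under these assumptions, not how anemoi shards edges internally.

```python
from bisect import bisect_left
from itertools import accumulate
from typing import List, Sequence


def edge_ranges_per_rank(dst_sorted: Sequence[int], node_sizes: Sequence[int]) -> List[range]:
    """Give each rank the edges whose destination lies in its node partition.

    `node_sizes` are per-rank partition sizes of the destination nodes,
    in the same form as a shard-info field.
    """
    node_offsets = [0, *accumulate(node_sizes)]
    # dst_sorted is non-decreasing, so each rank's edges form one contiguous run
    cuts = [bisect_left(dst_sorted, offset) for offset in node_offsets]
    return [range(cuts[r], cuts[r + 1]) for r in range(len(node_sizes))]
```

For destinations [0, 0, 1, 2, 2, 2, 3, 4] and node sizes [2, 3], rank 0 would receive edges 0 to 2 (destinations 0 and 1) and rank 1 edges 3 to 7.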
- class anemoi.models.layers.mapper.GNNForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)
Bases: GNNBaseMapper
Graph Neural Network Mapper from data -> hidden.
- class anemoi.models.layers.mapper.GNNBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)
Bases: GNNBaseMapper
Graph Neural Network Mapper from hidden -> data.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tensor
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
- class anemoi.models.layers.mapper.PointWiseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None)
Bases: BaseMapper, ABC
PointWise Mapper from hidden -> data or data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
- class anemoi.models.layers.mapper.PointWiseForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None, **kwargs)
Bases: PointWiseMapper
PointWise Mapper from data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
- class anemoi.models.layers.mapper.PointWiseBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int, initialise_data_extractor_zero: bool = False, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: dict | None = None, **kwargs)
Bases: PointWiseMapper
PointWise Mapper from hidden -> data.
- class anemoi.models.layers.mapper.TransformerBaseMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, window_size: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False, cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)
Bases: BaseMapper, ABC
Transformer Base Mapper from hidden -> data or data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, edges_are_dst_sorted: bool = True, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
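The options softcap and window_size are not described on this page. Assuming they follow common conventions (tanh soft-capping of attention logits, as in FlashAttention, and a symmetric local window), their effect on one square attention matrix can be sketched as below. The helper is hypothetical, takes precomputed scores, and ignores heads, scaling and ALiBi slopes.

```python
import math
from typing import List, Optional, Sequence


def attention_weights(
    scores: Sequence[Sequence[float]],
    softcap: Optional[float] = None,
    window_size: Optional[int] = None,
) -> List[List[float]]:
    """Row-wise softmax of raw attention scores with two optional tweaks.

    Assumptions (illustrative only): softcap maps each score to
    softcap * tanh(score / softcap), and window_size lets query i attend
    only to keys j with |i - j| <= window_size.
    """
    weights = []
    for i, row in enumerate(scores):
        logits = []
        for j, s in enumerate(row):
            if softcap is not None:
                s = softcap * math.tanh(s / softcap)  # bounded to (-softcap, softcap)
            if window_size is not None and abs(i - j) > window_size:
                s = -math.inf  # outside the local window: zero weight
            logits.append(s)
        m = max(logits)  # finite, since the diagonal is always inside the window
        exps = [math.exp(s - m) for s in logits]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights
```

Capping keeps a single very large score from saturating the softmax, which is the usual motivation for the option.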
- class anemoi.models.layers.mapper.TransformerForwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm: bool = False, dropout_p: float = 0.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float = None, use_alibi_slopes: bool = False, cpu_offload: bool = False, window_size: int | None = None, use_rotary_embeddings: bool = False, layer_kernels: DotDict, **kwargs)
Bases: TransformerBaseMapper
Transformer Mapper from data -> hidden.
- forward(x: Tuple[Tensor, Tensor], batch_size: int, shard_info: BipartiteGraphShardInfo, edge_attr: Tensor | None = None, edge_index: Tensor | SparseTensor | None = None, model_comm_group: ProcessGroup | None = None, keep_x_dst_sharded: bool = False, **kwargs) → Tuple[Tensor, Tensor]
Forward pass of the mapper.
- Parameters:
x (PairTensor) – Input tensor pair (source, destination).
batch_size (int) – Batch size.
shard_info (BipartiteGraphShardInfo) – Shard metadata. Each field is a list of per-rank partition sizes along the sharded dimension, or None if the tensor is replicated.
edge_attr (Tensor, optional) – Edge attributes (required for graph-based mappers).
edge_index (Adj, optional) – Edge indices (required for graph-based mappers).
model_comm_group (ProcessGroup, optional) – Model communication group.
keep_x_dst_sharded (bool, optional) – Whether to keep destination sharded, by default False.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
**kwargs (dict) – Additional keyword arguments passed to the mapper implementation.
- Returns:
Mapper output tensor or tensor pair.
- Return type:
Tensor or PairTensor
- class anemoi.models.layers.mapper.TransformerBackwardMapper(*, in_channels_src: int, in_channels_dst: int, hidden_dim: int, out_channels_dst: int | None = None, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm: bool = False, dropout_p: float = 0.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', attention_implementation: str = 'flash_attention', softcap: float = None, use_alibi_slopes: bool = False, cpu_offload: bool = False, window_size: int | None = None, use_rotary_embeddings: bool = False, layer_kernels: DotDict, **kwargs)
Bases: TransformerBaseMapper
Transformer Mapper from hidden -> data.
Processors
- class anemoi.models.layers.processor.NoOpProcessor(**kwargs)
Bases: Module
No-op processor, used for ablations.
- forward(x: Tensor, *args, **kwargs) → Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.processor.BaseProcessor(*, num_layers: int, num_channels: int, num_chunks: int, cpu_offload: bool = False, gradient_checkpointing: bool = True, layer_kernels: DotDict, **kwargs)
Bases: Module, ABC
Base Processor.
- forward(x: Tensor, *args, **kwargs) → Tensor
Forward pass of the processor.
- class anemoi.models.layers.processor.PointWiseMLPProcessor(*, num_layers: int, num_channels: int, num_chunks: int, mlp_hidden_ratio: float, cpu_offload: bool = False, dropout_p: float = 0.0, layer_kernels: DotDict, **kwargs)
Bases: BaseProcessor
Point-wise MLP Processor.
- forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, model_comm_group: ProcessGroup | None = None, *args, **kwargs) → Tensor
Forward pass of the processor.
- class anemoi.models.layers.processor.TransformerProcessor(*, num_layers: int, num_channels: int, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, attn_channels: int | None = None, qk_norm=False, dropout_p: float = 0.0, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, window_size: int | None = None, cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)
Bases: BaseProcessor
Transformer Processor.
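The non-'mlp' values of mlp_implementation read like the gated-linear-unit family (GLU, ReGLU, GEGLU, SwiGLU). Assuming that, the gated hidden layer differs only in the gate activation, as sketched below; the helper names, weight layout and absence of biases are illustrative choices, not anemoi's implementation.

```python
import math
from typing import Callable, Dict, List, Sequence

Matrix = Sequence[Sequence[float]]


def matvec(w: Matrix, x: Sequence[float]) -> List[float]:
    return [sum(wij * xj for wij, xj in zip(row, x)) for row in w]


def sigmoid(z: float) -> float:
    # numerically stable logistic function
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    ez = math.exp(z)
    return ez / (1.0 + ez)


def gelu(z: float) -> float:
    # exact GELU via the error function
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


# Assumed mapping from option name to gate activation
GATE_ACTIVATIONS: Dict[str, Callable[[float], float]] = {
    "glu": sigmoid,
    "reglu": lambda z: max(z, 0.0),
    "geglu": gelu,
    "swiglu": lambda z: z * sigmoid(z),  # Swish / SiLU
}


def gated_hidden(kind: str, x: Sequence[float], w_gate: Matrix, w_value: Matrix) -> List[float]:
    """Hidden layer of a gated MLP: act(W_gate x) * (W_value x), element-wise."""
    act = GATE_ACTIVATIONS[kind]
    return [act(g) * v for g, v in zip(matvec(w_gate, x), matvec(w_value, x))]
```

A plain 'mlp' would instead apply one activation to a single projection, so gated variants need a second weight matrix for the same hidden width.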
- class anemoi.models.layers.processor.GNNProcessor(*, num_channels: int, num_layers: int, num_chunks: int, mlp_extra_layers: int, edge_dim: int, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, **kwargs)
Bases: BaseProcessor
GNN Processor.
- forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, edges_are_dst_sorted: bool = True, *args, **kwargs) → Tensor
Run the GNN processor.
- Parameters:
x (Tensor) – Node features.
batch_size (int) – Batch size.
shard_info (GraphShardInfo) – Shard metadata for node and edge tensors.
edge_attr (Tensor) – Edge attributes.
edge_index (Adj) – Edge indices.
model_comm_group (ProcessGroup, optional) – Model communication group.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
*args (tuple) – Additional positional arguments.
**kwargs (dict) – Additional keyword arguments passed to processor blocks.
- Returns:
Processed node features.
- Return type:
Tensor
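A single message-passing step with sum aggregation at the destination node can be sketched as below. This is a deliberately simplified, hypothetical stand-in: the message is x[src] + edge_attr in place of a learned MLP, edge features are assumed to share the node dimension, and the aggregation GNNProcessor actually uses is not stated on this page.

```python
from typing import List, Sequence


def aggregate_messages(
    x: Sequence[Sequence[float]],
    edge_index: Sequence[Sequence[int]],
    edge_attr: Sequence[Sequence[float]],
) -> List[List[float]]:
    """Sum one simplified message per edge into its destination node.

    Assumes edge_index[0] holds sources and edge_index[1] destinations.
    """
    num_nodes, dim = len(x), len(x[0])
    agg = [[0.0] * dim for _ in range(num_nodes)]
    for e, (s, d) in enumerate(zip(*edge_index)):
        for k in range(dim):
            # message = source features + edge features (placeholder for an MLP)
            agg[d][k] += x[s][k] + edge_attr[e][k]
    return agg
```

Nodes without incoming edges keep a zero aggregate, which a real block would then combine with the node's own features.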
- class anemoi.models.layers.processor.GraphTransformerProcessor(*, num_layers: int, num_channels: int, num_chunks: int, num_heads: int, mlp_hidden_ratio: float, edge_dim: int, attn_channels: int | None = None, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', cpu_offload: bool = False, layer_kernels: DotDict, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases: BaseProcessor
Graph Transformer Processor.
- forward(x: Tensor, batch_size: int, shard_info: GraphShardInfo, edge_attr: Tensor, edge_index: Tensor | SparseTensor, model_comm_group: ProcessGroup | None = None, edges_are_dst_sorted: bool = True, *args, **kwargs) → Tensor
Run the graph-transformer processor.
- Parameters:
x (Tensor) – Node features.
batch_size (int) – Batch size.
shard_info (GraphShardInfo) – Shard metadata for node and edge tensors.
edge_attr (Tensor) – Edge attributes.
edge_index (Adj) – Edge indices.
model_comm_group (ProcessGroup, optional) – Model communication group.
edges_are_dst_sorted (bool, optional) – Whether edge_index and edge_attr are already ordered by destination node. Edges from graph providers already are. Pass False for custom full-graph edges that are not ordered this way. If edges are already sharded, each rank is expected to already have the right edges for its local destination nodes.
*args (tuple) – Additional positional arguments.
**kwargs (dict) – Additional keyword arguments passed to processor blocks.
- Returns:
Processed node features.
- Return type:
Tensor
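In a graph transformer, attention weights are normalised over the incoming edges of each destination node rather than over a full sequence. The hypothetical helper below sketches that per-destination softmax on precomputed per-edge scores; how GraphTransformerProcessor computes the scores (queries, keys, edge features) is not shown here.

```python
import math
from collections import defaultdict
from typing import Dict, List, Sequence


def edge_softmax(scores: Sequence[float], dst: Sequence[int]) -> List[float]:
    """Softmax of per-edge scores grouped by destination node,
    so the weights on the edges into every node sum to 1."""
    max_per_dst: Dict[int, float] = defaultdict(lambda: -math.inf)
    for s, d in zip(scores, dst):
        max_per_dst[d] = max(max_per_dst[d], s)
    # subtract the per-destination maximum for numerical stability
    exps = [math.exp(s - max_per_dst[d]) for s, d in zip(scores, dst)]
    denom: Dict[int, float] = defaultdict(float)
    for e, d in zip(exps, dst):
        denom[d] += e
    return [e / denom[d] for e, d in zip(exps, dst)]
```

A node with a single incoming edge gets weight 1.0 on that edge regardless of its score.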
Chunks
Blocks
- class anemoi.models.layers.block.BaseBlock(**kwargs)
Bases: Module, ABC
Base class for network blocks.
- abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo | BipartiteGraphShardInfo, batch_size: int, size: Tuple[int, int] | None = None, model_comm_group: ProcessGroup | None = None, **layer_kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.PointWiseMLPProcessorBlock(*, num_channels: int, hidden_dim: int, layer_kernels: DotDict, dropout_p: float = 0.0)
Bases: BaseBlock
Point-wise block with MLPs.
- forward(x: Tensor, shard_info: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, **layer_kwargs) → tuple[Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.TransformerProcessorBlock(*, num_channels: int, hidden_dim: int, num_heads: int, window_size: int | None, layer_kernels: DotDict, attn_channels: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)
Bases: BaseBlock
Transformer block with MultiHeadSelfAttention and MLPs.
- forward(x: Tensor, shard_info: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, cond: Tensor | None = None, **layer_kwargs) → tuple[Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.TransformerMapperBlock(*, num_channels: int, hidden_dim: int, num_heads: int, window_size: int | None, layer_kernels: DotDict, attn_channels: int | None = None, dropout_p: float = 0.0, qk_norm: bool = False, attention_implementation: str = 'flash_attention', mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)
Bases: TransformerProcessorBlock
Transformer mapper block with MultiHeadCrossAttention and MLPs.
- forward(x: Tuple[Tensor, Tensor | None], shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None, cond: tuple[Tensor, Tensor] | None = None) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.GraphConvBaseBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)
Bases: BaseBlock
Message passing block with MLPs for node embeddings.
- abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo | BipartiteGraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.GraphConvProcessorBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)
Bases: GraphConvBaseBlock
- forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.GraphConvMapperBlock(*, in_channels: int, out_channels: int, num_chunks: int, mlp_extra_layers: int = 0, mlp_hidden_ratio: float = 1.0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = True, layer_kernels: DotDict, edge_dim: int | None = None, **kwargs)
Bases: GraphConvBaseBlock
- forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, model_comm_group: ProcessGroup | None = None, size: Tuple[int, int] | None = None, **layer_kwargs) → tuple[Tensor, Tensor]
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.GraphTransformerBaseBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, attn_channels: int | None = None, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Message passing block with MLPs for node embeddings.
- shard_qkve_heads(query: Tensor, key: Tensor, value: Tensor, edges: Tensor, shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) tuple[Tensor, Tensor, Tensor, Tensor, list[int] | None]
Shards the query, key, value and edge tensors along the head dimension using all_to_all_transpose.
- shard_output_seq(out: Tensor, shard_info: BipartiteGraphShardInfo, head_shard_sizes: list[int] | None, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor
Shards the output tensor along the sequence dimension using all_to_all_transpose.
- abstractmethod forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
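The head/sequence resharding performed by shard_qkve_heads and shard_output_seq can be pictured with plain Python lists. The sketch below is a single-process simulation, not the library's code: each outer list entry stands for one rank, split_sizes is a hypothetical near-equal partition helper, and the real methods exchange tensors with all_to_all_transpose across model_comm_group rather than slicing lists.

```python
def split_sizes(total, parts):
    """Near-equal partition sizes; earlier parts take the remainder (assumed rule)."""
    base, rem = divmod(total, parts)
    return [base + (1 if i < rem else 0) for i in range(parts)]


def seq_to_heads(seq_shards, num_heads):
    """Simulate an all-to-all that turns sequence shards into head shards.

    seq_shards[r] holds rank r's rows (its slice of the sequence); every row has
    one entry per head. Afterwards rank r holds all rows, but only its own heads.
    """
    world_size = len(seq_shards)
    head_sizes = split_sizes(num_heads, world_size)
    out = []
    for dst in range(world_size):
        lo = sum(head_sizes[:dst])
        hi = lo + head_sizes[dst]
        # dst collects its head columns from every source rank, in rank order
        out.append([row[lo:hi] for src in range(world_size) for row in seq_shards[src]])
    return out


def heads_to_seq(head_shards, seq_sizes):
    """Inverse exchange: each rank gets its sequence slice back with all heads."""
    world_size = len(head_shards)
    out = []
    for dst in range(world_size):
        lo = sum(seq_sizes[:dst])
        hi = lo + seq_sizes[dst]
        # re-join the head columns that each rank holds for the rows dst owns
        out.append([[v for src in range(world_size) for v in head_shards[src][i]] for i in range(lo, hi)])
    return out


num_heads, seq_sizes = 4, [3, 2]
full = [[f"s{i}h{h}" for h in range(num_heads)] for i in range(sum(seq_sizes))]
seq_shards = [full[:3], full[3:]]
head_shards = seq_to_heads(seq_shards, num_heads)
print(head_shards[1][0])  # ['s0h2', 's0h3']: rank 1 sees every position, heads 2-3 only
```

Switching to head shards lets each rank run attention over the full sequence for a subset of heads; the inverse exchange restores the sequence sharding afterwards.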
- class anemoi.models.layers.block.GraphTransformerMapperBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, shard_strategy: str = 'edges', graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases:
GraphTransformerBaseBlock
Graph Transformer Block for node embeddings.
- forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: BipartiteGraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, cond: tuple[Tensor, Tensor] | None = None, edges_are_dst_sorted: bool = True, **layer_kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.block.GraphTransformerProcessorBlock(*, in_channels: int, hidden_dim: int, out_channels: int, num_heads: int, edge_dim: int, bias: bool = True, qk_norm: bool = False, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', update_src_nodes: bool = False, layer_kernels: DotDict, graph_attention_backend: str = 'triton', edge_pre_mlp: bool = False, **kwargs)
Bases:
GraphTransformerBaseBlock
Graph Transformer Block for node embeddings.
- forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, shard_info: GraphShardInfo, batch_size: int, size: int | tuple[int, int], model_comm_group: ProcessGroup | None = None, cond: Tensor | None = None, edges_are_dst_sorted: bool = True, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Graph
- class anemoi.models.layers.graph.TrainableTensor(tensor_size: int, trainable_size: int)
Bases:
Module
Trainable Tensor Module.
- forward(x: Tensor, batch_size: int) Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- class anemoi.models.layers.graph.NamedNodesAttributes(trainable_parameters: dict[str, int], graph_data: HeteroData)
Bases:
Module
Named Nodes Attributes information.
- attr_ndims
Total dimension of node attributes (non-trainable + trainable) for each group of nodes.
- trainable_tensors
Dictionary of trainable tensors for each group of nodes.
- Type:
nn.ModuleDict
- forward(self, name: str, batch_size: int) Tensor
Get the node attributes to be passed through the graph neural network.
Graph Providers
- anemoi.models.layers.graph_provider.create_graph_provider(graph: HeteroData | None = None, edge_attributes: list[str] | None = None, src_size: int | None = None, dst_size: int | None = None, trainable_size: int = 0) BaseGraphProvider
Factory function to create appropriate graph provider.
Returns StaticGraphProvider if graph has edges, otherwise returns NoOpGraphProvider for edge-less architectures.
- Parameters:
graph (HeteroData, optional) – Graph containing edges (for static mode)
edge_attributes (list[str], optional) – Edge attributes to use (for static mode)
src_size (int, optional) – Source grid size (for static mode)
dst_size (int, optional) – Destination grid size (for static mode)
trainable_size (int, optional) – Trainable tensor size, by default 0
- Returns:
Appropriate graph provider instance
- Return type:
BaseGraphProvider
- class anemoi.models.layers.graph_provider.BaseGraphProvider(*args: Any, **kwargs: Any)
Bases:
Module, ABC
Base class for graph edge providers.
Graph providers encapsulate the logic for supplying edge indices and attributes to mapper and processor layers. This allows for different strategies (static, dynamic, etc.).
- abstractmethod get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None] | Tensor
Get edge information.
- Parameters:
batch_size (int, optional) – Number of times to expand the edge index (used by static mode)
src_coords (Tensor, optional) – Source node coordinates (used by dynamic mode for k-NN, radius graphs, etc.)
dst_coords (Tensor, optional) – Destination node coordinates (used by dynamic mode for k-NN, radius graphs, etc.)
model_comm_group (ProcessGroup, optional) – Model communication group
shard_edges (bool, optional) – Whether to shard edges, by default True
- Returns:
For standard providers: an (edge_attr, edge_index, edge_shard_sizes) tuple. For sparse providers: a sparse projection matrix.
- Return type:
Union[tuple[Tensor, Adj, Optional[ShardSizes]], Tensor]
- class anemoi.models.layers.graph_provider.StaticGraphProvider(graph: HeteroData, edge_attributes: list[str], src_size: int, dst_size: int, trainable_size: int)
Bases:
BaseGraphProvider
Provider for static graphs with fixed edge structure.
This provider owns all graph-related state including edge attributes, edge indices, and trainable parameters.
- get_edges(batch_size: int, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, act_checkpoint: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None]
Get edge attributes and expanded edge index for static graph.
- Parameters:
batch_size (int) – Number of times to expand the edge index
src_coords (Tensor, optional) – Source node coordinates (ignored for static graphs)
dst_coords (Tensor, optional) – Destination node coordinates (ignored for static graphs)
model_comm_group (ProcessGroup, optional) – Model communication group
shard_edges (bool, optional) – Whether to shard edges, by default True.
act_checkpoint (bool, optional) – Whether to use gradient checkpointing, by default True.
- Returns:
Edge attributes, expanded edge index, and optional edge_shard_sizes. edge_shard_sizes is a list of per-rank partition sizes when shard_edges=True, otherwise None.
- Return type:
tuple[Tensor, Adj, Optional[ShardSizes]]
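The batch expansion of the edge index and the edge_shard_sizes list can be sketched as below. Both helpers are hypothetical: the offset rule assumes the common PyTorch Geometric batching convention (copy b shifts source ids by b * src_size and destination ids by b * dst_size), and the near-equal split is an assumed partitioning rule, not necessarily the one StaticGraphProvider applies.

```python
def expand_edge_index(edge_index, batch_size, src_size, dst_size):
    """Repeat a bipartite edge index once per batch member with shifted node ids."""
    src, dst = edge_index
    new_src = [s + b * src_size for b in range(batch_size) for s in src]
    new_dst = [d + b * dst_size for b in range(batch_size) for d in dst]
    return new_src, new_dst


def edge_shard_sizes(num_edges, world_size):
    """Per-rank partition sizes that add up to num_edges."""
    base, rem = divmod(num_edges, world_size)
    return [base + (1 if r < rem else 0) for r in range(world_size)]


edges = ([0, 1, 1], [0, 0, 1])  # (source ids, destination ids)
expanded = expand_edge_index(edges, batch_size=2, src_size=2, dst_size=2)
print(expanded)                               # ([0, 1, 1, 2, 3, 3], [0, 0, 1, 2, 2, 3])
print(edge_shard_sizes(len(expanded[0]), 4))  # [2, 2, 1, 1]
```

Shifting the ids keeps the batch copies disjoint, so a single message-passing call can process the whole batch as one large graph.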
- class anemoi.models.layers.graph_provider.NoOpGraphProvider
Bases:
BaseGraphProvider
Provider for edge-less architectures (e.g., Transformers).
Returns None for edges and has edge_dim=0. Used when the mapper/processor does not require graph structure (e.g., pure attention-based models).
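The dispatch rule documented for create_graph_provider (a static provider when the graph has edges, a no-op provider otherwise) is small enough to sketch. The classes and the factory below are hypothetical stand-ins with made-up names; they only mirror the documented behaviour, namely that the no-op provider returns None for edges and reports edge_dim=0.

```python
from dataclasses import dataclass


@dataclass
class StaticProviderSketch:
    edge_index: tuple  # (source ids, destination ids), fixed at construction
    edge_dim: int

    def get_edges(self):
        return self.edge_index


class NoOpProviderSketch:
    edge_dim = 0

    def get_edges(self, *args, **kwargs):
        return None  # edge-less layers (e.g. pure attention) receive no edges


def create_provider_sketch(edge_index=None, edge_dim=0):
    """Return a static provider if there are edges, otherwise a no-op provider."""
    if edge_index is not None and len(edge_index[0]) > 0:
        return StaticProviderSketch(edge_index, edge_dim)
    return NoOpProviderSketch()


print(type(create_provider_sketch()).__name__)                 # NoOpProviderSketch
print(create_provider_sketch(([0], [1]), edge_dim=3).edge_dim)  # 3
```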
- class anemoi.models.layers.graph_provider.DynamicGraphProvider(edge_dim: int)
Bases:
BaseGraphProvider
Provider for dynamic graphs where edges are supplied at runtime.
Does not support trainable edge parameters.
Future implementation will support on-the-fly graph construction via build_graph() (e.g., k-NN graphs, radius graphs, adaptive connectivity).
- build_graph(src_nodes: Tensor, dst_nodes: Tensor, **kwargs) tuple[Tensor, Tensor | SparseTensor]
Build graph dynamically from source and destination nodes.
This method will be implemented in the future to support on-the-fly graph construction (e.g., k-NN graphs, radius graphs, etc.).
- Parameters:
src_nodes (Tensor) – Source node features/positions
dst_nodes (Tensor) – Destination node features/positions
**kwargs – Additional parameters for graph construction algorithm
- Returns:
Edge attributes and edge index
- Return type:
tuple[Tensor, Adj]
- Raises:
NotImplementedError – This functionality is not yet implemented
- get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, act_checkpoint: bool = True) tuple[Tensor, Tensor | SparseTensor, list[int] | None]
Get dynamic edges constructed from node coordinates.
Calls build_graph() to construct edges on-the-fly using k-NN, radius graphs, etc.
- Parameters:
batch_size (int, optional) – Batch size (currently unused, reserved for future implementation)
src_coords (Tensor, optional) – Source node coordinates
dst_coords (Tensor, optional) – Destination node coordinates
model_comm_group (ProcessGroup, optional) – Model communication group
shard_edges (bool, optional) – Whether to shard edges, by default True
act_checkpoint (bool, optional) – Whether to use gradient checkpointing, by default True.
- Returns:
Edge attributes, edge index, and optional edge_shard_sizes.
- Return type:
tuple[Tensor, Adj, Optional[ShardSizes]]
- Raises:
ValueError – If coordinates are not provided
NotImplementedError – If build_graph() is not yet implemented
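build_graph() is documented as future work, so the following only illustrates one of the strategies it names, a k-NN graph. knn_edges is a hypothetical helper that links each destination node to its k nearest source nodes and returns the distances as a stand-in edge attribute. It uses plain Euclidean distance; nodes given as latitude/longitude would need a spherical distance instead.

```python
import math


def knn_edges(src_coords, dst_coords, k):
    """Connect every destination node to its k nearest source nodes.

    Returns (src_ids, dst_ids, distances) with edges grouped by destination.
    """
    if not src_coords or not dst_coords:
        raise ValueError("source and destination coordinates must be provided")
    src_ids, dst_ids, dists = [], [], []
    for j, p in enumerate(dst_coords):
        # rank all source nodes by distance to destination node j (stable on ties)
        ranked = sorted(range(len(src_coords)), key=lambda i: math.dist(src_coords[i], p))
        for i in ranked[:k]:
            src_ids.append(i)
            dst_ids.append(j)
            dists.append(math.dist(src_coords[i], p))
    return src_ids, dst_ids, dists


src = [(0.0, 0.0), (1.0, 0.0), (5.0, 5.0)]
dst = [(0.1, 0.0), (4.0, 4.0)]
s_ids, d_ids, dist = knn_edges(src, dst, k=2)
print(s_ids, d_ids)  # [0, 1, 2, 1] [0, 0, 1, 1]
```

This brute-force version costs O(N_src log N_src) per destination node; a spatial index such as a k-d tree would be the usual choice at realistic grid sizes.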
- class anemoi.models.layers.graph_provider.ProjectionGraphProvider(graph: HeteroData | None = None, edges_name: tuple[str, str, str] | None = None, edge_weight_attribute: str | None = None, src_node_weight_attribute: str | None = None, file_path: str | Path | None = None, row_normalize: bool = False)
Bases:
BaseGraphProvider
Provider for sparse projection matrices.
Builds and stores a sparse projection matrix from a graph or a file.
- get_edges(batch_size: int | None = None, src_coords: Tensor | None = None, dst_coords: Tensor | None = None, model_comm_group: ProcessGroup | None = None, shard_edges: bool = True, device: device | None = None) Tensor
Return the sparse projection matrix.
- Parameters:
batch_size (int, optional) – Unused for sparse providers
src_coords (Tensor, optional) – Unused for sparse providers
dst_coords (Tensor, optional) – Unused for sparse providers
model_comm_group (ProcessGroup, optional) – Unused for sparse providers
shard_edges (bool, optional) – Unused for sparse providers
device (torch.device, optional) – Target device for matrix
- Returns:
Sparse projection matrix
- Return type:
Tensor
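Applying a sparse projection matrix amounts to a weighted sum over the source nodes linked to each destination node. The sketch stores the matrix as a {dst: {src: weight}} dictionary and assumes that rows index destination nodes and that row_normalize rescales each row to sum to one; both are assumptions about ProjectionGraphProvider, which returns a torch sparse tensor rather than a dictionary. The helper names are hypothetical.

```python
from collections import defaultdict


def build_projection(src_ids, dst_ids, weights, row_normalize=False):
    """Store a dst-by-src sparse matrix as {dst: {src: weight}} (COO-like input)."""
    rows = defaultdict(dict)
    for s, d, w in zip(src_ids, dst_ids, weights):
        rows[d][s] = rows[d].get(s, 0.0) + w  # duplicate entries are summed, as in COO
    if row_normalize:
        for cols in rows.values():
            total = sum(cols.values())
            if total != 0:
                for s in cols:
                    cols[s] /= total  # each destination row now sums to one
    return dict(rows)


def project(matrix, src_values, num_dst):
    """Apply the projection: out[d] = sum over s of M[d, s] * x[s]."""
    return [sum(w * src_values[s] for s, w in matrix.get(d, {}).items()) for d in range(num_dst)]


m = build_projection([0, 1, 1, 2], [0, 0, 1, 1], [1.0, 3.0, 1.0, 1.0], row_normalize=True)
print(project(m, [4.0, 8.0, 2.0], num_dst=2))  # [7.0, 5.0]
```

With row normalisation every output is a weighted average of its source values, which keeps projected fields on the same scale as the input.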
Conv
- class anemoi.models.layers.conv.GraphConv(in_channels: int, out_channels: int, layer_kernels: DotDict, mlp_extra_layers: int = 0, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp', **kwargs)
Bases:
MessagePassing
Message passing module for convolutional node and edge interactions.
- forward(x: Tuple[Tensor, Tensor | None], edge_attr: Tensor, edge_index: Tensor | SparseTensor, size: Tuple[int, int] | None = None)
Runs the forward pass of the module.
- message(x_i: Tensor, x_j: Tensor, edge_attr: Tensor, dim_size: int | None = None) Tensor
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
- aggregate(edges_new: Tensor, edge_index: Tensor | SparseTensor, dim_size: int | None = None) tuple[Tensor, Tensor]
Aggregates messages from neighbors as \(\bigoplus_{j \in \mathcal{N}(i)}\).
Takes in the output of message computation as first argument and any argument which was initially passed to propagate(). By default, this function will delegate its call to the underlying Aggregation module to reduce messages as specified in __init__() by the aggr argument.
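The message/aggregate split can be traced by hand. In this plain-Python sketch a trivial element-wise sum stands in for GraphConv's MLP and aggregation is a sum over incoming edges; x_i is the receiving (destination) node and x_j the sending (source) node, following the propagate() naming convention. message and propagate_sketch are hypothetical names, and the sketch shows the data flow only, not GraphConv's exact update.

```python
def message(x_i, x_j, e_ij):
    """Stand-in for the edge MLP: combine receiver, sender and edge features."""
    return [a + b + c for a, b, c in zip(x_i, x_j, e_ij)]


def propagate_sketch(x_src, x_dst, edge_index, edge_attr):
    """One message-passing step with sum aggregation onto destination nodes."""
    src_ids, dst_ids = edge_index
    out = [[0.0] * len(x_dst[0]) for _ in x_dst]
    for j, i, e_ij in zip(src_ids, dst_ids, edge_attr):
        m = message(x_dst[i], x_src[j], e_ij)        # message from source j to destination i
        out[i] = [o + v for o, v in zip(out[i], m)]  # scatter-add into node i
    return out


x_src = [[1.0], [2.0]]
x_dst = [[10.0], [20.0]]
edge_index = ([0, 1, 1], [0, 0, 1])  # (source ids, destination ids)
edge_attr = [[0.5], [0.5], [1.0]]
print(propagate_sketch(x_src, x_dst, edge_index, edge_attr))  # [[24.0], [23.0]]
```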
- class anemoi.models.layers.conv.GraphTransformerConv(out_channels: int, dropout: float = 0.0, **kwargs)
Bases:
MessagePassing
Message passing part of graph transformer operator.
Adapted from ‘Masked Label Prediction: Unified Message Passing Model for Semi-Supervised Classification’ (https://arxiv.org/abs/2009.03509)
- forward(query: Tensor, key: Tensor, value: Tensor, edge_attr: Tensor | None, edge_index: Tensor | SparseTensor, size: Tuple[int, int] | None = None)
Runs the forward pass of the module.
- message(heads: int, query_i: Tensor, key_j: Tensor, value_j: Tensor, edge_attr: Tensor | None, index: Tensor, ptr: Tensor | None, size_i: int | None) Tensor
Constructs messages from node \(j\) to node \(i\) in analogy to \(\phi_{\mathbf{\Theta}}\) for each edge in edge_index. This function can take any argument as input which was initially passed to propagate(). Furthermore, tensors passed to propagate() can be mapped to the respective nodes \(i\) and \(j\) by appending _i or _j to the variable name, e.g. x_i and x_j.
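For a single head, the attention computed over each node's incoming edges can be sketched as follows. It follows the formulation of the Masked Label Prediction paper cited above, in which edge features are added to both keys and values and scores are normalised with a softmax over the edges sharing a destination node. graph_attention is a hypothetical name, learned projections are omitted, keys, values and edge features are assumed to have the query's width, and the exact way GraphTransformerConv injects edge_attr is an assumption.

```python
import math
from collections import defaultdict


def graph_attention(query, key, value, edge_attr, edge_index):
    """Single-head graph attention with edge features added to keys and values.

    query[i] belongs to destination node i; key[j] and value[j] to source node j;
    edge_attr[e] to edge e = (src_ids[e], dst_ids[e]).
    """
    src_ids, dst_ids = edge_index
    d = len(query[0])
    scores = defaultdict(list)  # destination node -> [(score, edge id)]
    for e, (j, i) in enumerate(zip(src_ids, dst_ids)):
        k = [a + b for a, b in zip(key[j], edge_attr[e])]
        scores[i].append((sum(q * kk for q, kk in zip(query[i], k)) / math.sqrt(d), e))
    out = [[0.0] * d for _ in query]  # nodes without incoming edges stay zero
    for i, items in scores.items():
        m = max(s for s, _ in items)  # subtract the max for a numerically stable softmax
        z = sum(math.exp(s - m) for s, _ in items)
        for s, e in items:
            alpha = math.exp(s - m) / z
            v = [a + b for a, b in zip(value[src_ids[e]], edge_attr[e])]
            out[i] = [o + alpha * vv for o, vv in zip(out[i], v)]
    return out


# a zero query scores both edges equally, so node 0 receives the mean of the values
print(graph_attention([[0.0, 0.0]], [[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [3.0, 2.0]],
                      [[0.0, 0.0], [0.0, 0.0]], ([0, 1], [0, 0])))  # [[2.0, 1.0]]
```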
Attention
- class anemoi.models.layers.attention.MultiHeadSelfAttention(num_heads: int, embed_dim: int, layer_kernels: DotDict, attn_channels: int | None = None, qkv_bias: bool = False, qk_norm: bool = False, is_causal: bool = False, window_size: int | None = None, dropout_p: float = 0.0, attention_implementation: str = 'flash_attention', softcap: float | None = None, use_alibi_slopes: bool = False, use_rotary_embeddings: bool = False)
Bases:
Module
Multi Head Self Attention PyTorch Layer.
Allows for two different attention implementations:
- scaled dot product attention, see https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
- flash attention, see https://github.com/Dao-AILab/flash-attention
The config parameter “model.processor.attention_implementation” controls which attention implementation is used.
- “scaled_dot_product_attention” (SDPA)
SDPA is a PyTorch function, so it is easier to use than flash attention but less performant. It runs on both CPUs and GPUs.
- “flash_attention”
Flash attention is optimised for efficient use of the GPU's memory hierarchy. It loads smaller chunks into fast local memory and fuses attention into a single kernel to reduce the number of passes through memory. It runs on NVIDIA Ampere (e.g. A100) or newer GPUs and on AMD MI200 or newer GPUs; check the flash-attention GitHub repository for the full requirements. You have to install flash attention yourself: on x86 systems, prebuilt wheels are available from the GitHub repository, while on aarch64 systems you have to build flash attention from source.
- forward(x: Tensor, grid_shard_sizes: GraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
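Whichever implementation is selected, the layer computes the same quantity per head, softmax(QK^T / sqrt(d)) V; SDPA and flash attention differ in how that computation is scheduled on the hardware, not in the result (up to floating-point differences). A single-head reference version in plain Python, with an optional causal mask standing in for is_causal, is sketched below. sdpa_reference is a hypothetical teaching helper, not the code path MultiHeadSelfAttention uses, and the causal branch assumes queries and keys have the same length.

```python
import math


def sdpa_reference(q, k, v, causal=False):
    """softmax(q k^T / sqrt(d)) v for one head; q, k, v are lists of row vectors."""
    d = len(q[0])
    out = []
    for i, q_i in enumerate(q):
        # a causal mask lets position i attend only to positions j <= i
        keys = list(range(i + 1)) if causal else list(range(len(k)))
        scores = [sum(a * b for a, b in zip(q_i, k[j])) / math.sqrt(d) for j in keys]
        m = max(scores)  # subtract the max for a numerically stable softmax
        w = [math.exp(s - m) for s in scores]
        z = sum(w)
        out.append([sum(w_n * v[j][c] for w_n, j in zip(w, keys)) / z for c in range(len(v[0]))])
    return out


q = k = [[0.0], [0.0]]  # zero queries give uniform weights over the visible keys
v = [[1.0], [3.0]]
print(sdpa_reference(q, k, v))               # [[2.0], [2.0]]
print(sdpa_reference(q, k, v, causal=True))  # [[1.0], [2.0]]
```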
- class anemoi.models.layers.attention.SDPAAttentionWrapper
Bases:
Module
Wrapper for PyTorch scaled dot product attention. To use this attention implementation, set model.processor.attention_implementation='scaled_dot_product_attention'.
- create_sliding_window_mask(B, H, Q_LEN, KV_LEN, window_size, device='cpu') Tensor
Create a mask for sliding window attention compatible with SDPA.
- Parameters:
- Returns:
2D attention mask
- Return type:
Tensor
- forward(query, key, value, batch_size: int, causal=False, window_size=None, dropout_p=0.0, softcap=None, alibi_slopes=None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
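One plausible reading of create_sliding_window_mask is a band mask in which query i may attend to key j only when |i - j| <= window_size. The hypothetical helper below builds that 2D boolean mask; the symmetric window is an assumption, and the batch and head arguments (B, H) of the documented function are left out. In torch.nn.functional.scaled_dot_product_attention, a boolean attn_mask marks with True the positions that take part in attention, which is the convention used here.

```python
def sliding_window_mask(q_len, kv_len, window_size):
    """True where query i may attend to key j, i.e. |i - j| <= window_size."""
    return [[abs(i - j) <= window_size for j in range(kv_len)] for i in range(q_len)]


for row in sliding_window_mask(4, 4, window_size=1):
    print("".join("x" if allowed else "." for allowed in row))
# xx..
# xxx.
# .xxx
# ..xx
```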
- class anemoi.models.layers.attention.FlashAttentionWrapper(use_rotary_embeddings: bool = False, head_dim: int = None)
Bases:
Module
Wrapper for Flash attention.
Uses either flash attention v2 or flash attention v3 (optimised for Hopper GPUs and newer), depending on which is installed. Flash attention v3 does not support rotary embeddings or ALiBi slopes; to use these features, downgrade to flash attention v2.
- forward(query, key, value, batch_size: int, causal: bool = False, window_size: int | None = None, dropout_p: float = 0.0, softcap: float | None = None, alibi_slopes: Tensor = None)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
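The feature restriction of flash attention v3 can be expressed as a small validation step. check_flash_features is hypothetical: it takes the installed major version as an argument instead of detecting it, and only encodes the documented rule that v3 supports neither rotary embeddings nor ALiBi slopes.

```python
def check_flash_features(version, use_rotary=False, use_alibi=False):
    """Validate requested features against the installed flash-attention major version."""
    if version not in (2, 3):
        raise ImportError("flash attention v2 or v3 must be installed")
    if version == 3 and (use_rotary or use_alibi):
        raise ValueError("flash attention v3 supports neither rotary embeddings nor ALiBi slopes; use v2")
    return version


print(check_flash_features(2, use_rotary=True))  # 2
```

Failing early at construction time surfaces the incompatibility as a clear configuration error instead of a failure inside the kernel call.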
- class anemoi.models.layers.attention.MultiHeadCrossAttention(*args, **kwargs)
Bases:
MultiHeadSelfAttention
Multi Head Cross Attention PyTorch Layer.
- forward(x: Tuple[Tensor, Tensor], shard_info: BipartiteGraphShardInfo, batch_size: int, model_comm_group: ProcessGroup | None = None) Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Multi-Layer Perceptron
- class anemoi.models.layers.mlp.GatedMLPLayer(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'])
Bases:
Module
Single gated feed-forward layer used by GLU variants.
- forward(x: Tensor) Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
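The gated variants accepted by mlp_implementation usually share one pattern, act(x W_gate) multiplied element-wise by x W_value, and differ only in the gate activation. The sketch assumes the common convention (sigmoid for glu, SiLU for swiglu, GELU for geglu, ReLU for reglu) and bias-free projections; that GatedMLPLayer uses exactly these choices is an assumption, and the helper names are hypothetical.

```python
import math

GATE_ACTIVATIONS = {
    "glu": lambda z: 1.0 / (1.0 + math.exp(-z)),                        # sigmoid
    "swiglu": lambda z: z / (1.0 + math.exp(-z)),                       # SiLU / Swish
    "geglu": lambda z: 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0))),  # exact GELU
    "reglu": lambda z: max(z, 0.0),                                     # ReLU
}


def matvec(w, x):
    """Multiply a weight matrix, given as a list of output rows, by a vector."""
    return [sum(w_k * x_k for w_k, x_k in zip(row, x)) for row in w]


def gated_layer(x, w_gate, w_value, variant):
    """act(W_gate x) * (W_value x), element-wise; no biases in this sketch."""
    act = GATE_ACTIVATIONS[variant]
    return [act(g) * v for g, v in zip(matvec(w_gate, x), matvec(w_value, x))]


x = [1.0, 2.0]
w_gate = [[1.0, 0.0], [0.0, -1.0]]  # gate pre-activations: [1.0, -2.0]
w_value = [[1.0, 1.0], [2.0, 0.0]]  # value projections:    [3.0, 2.0]
print(gated_layer(x, w_gate, w_value, "reglu"))  # [3.0, 0.0]
```

Because the gate multiplies the value path, these layers need two input projections where a plain "mlp" layer needs one.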
- anemoi.models.layers.mlp.build_feedforward_layer(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp') Module
Build one feed-forward layer module.
- anemoi.models.layers.mlp.build_feedforward_modules(in_features: int, out_features: int, layer_kernels: DotDict, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp') list[Module]
Build one feed-forward layer as a flat list of modules.
- class anemoi.models.layers.mlp.MLP(in_features: int, hidden_dim: int, out_features: int, layer_kernels: DotDict, n_extra_layers: int = 0, final_activation: bool = False, layer_norm: bool = True, mlp_implementation: Literal['mlp', 'glu', 'swiglu', 'geglu', 'reglu'] = 'mlp')
Bases:
Module
Multi-layer perceptron with optional checkpointing.
- forward(x: Tensor, **layer_kwargs) Tensor
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Utils
Compute integer hidden dimension from a (possibly fractional) MLP ratio.
- class anemoi.models.layers.utils.CheckpointWrapper(module: Module)
Bases:
Module
Wrapper for checkpointing a module.
- forward(*args, **kwargs)
Define the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
- anemoi.models.layers.utils.maybe_checkpoint(func, enabled: bool, *args, **kwargs)
Conditionally apply gradient checkpointing to a function.
- Parameters:
func (callable) – The function to potentially wrap with checkpointing
enabled (bool) – Whether to apply gradient checkpointing
*args – Arguments to pass to the function
**kwargs – Arguments to pass to the function
- Returns:
The result of calling func with the provided arguments
- anemoi.models.layers.utils.load_layer_kernels(kernel_config: DotDict | None = None, instance: bool = True) DotDict
Load layer kernels from the config.
This function tries to load the layer kernels from the config. If the layer kernel is not supplied, it will fall back to the torch.nn implementation.
- Parameters:
kernel_config (DotDict, optional) – Kernel configuration, e.g. {"Linear": {"_target_": "torch.nn.Linear"}}
instance (bool) – If True, instantiate the kernels. If False, return the config. This is useful for testing purposes. Defaults to True.
- Returns:
Container with layer factories.
- Return type:
DotDict
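The fallback behaviour can be sketched with importlib. load_kernels_sketch and resolve_target are hypothetical: the defaults mapping stands in for the torch.nn fallbacks, functools.partial stands in for however the real function builds its layer factories, and the usage below resolves standard-library classes only so that it runs without PyTorch.

```python
import functools
import importlib


def resolve_target(path):
    """Import "package.module.Name" and return the Name object."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def load_kernels_sketch(kernel_config=None, defaults=None, instance=True):
    """Merge user kernels over defaults, then optionally turn them into factories."""
    config = {name: {"_target_": target} for name, target in (defaults or {}).items()}
    config.update(kernel_config or {})  # a supplied kernel replaces the default entry
    if not instance:
        return config
    factories = {}
    for name, spec in config.items():
        extra = {key: val for key, val in spec.items() if key != "_target_"}
        # defer construction: the layer later calls the factory with its own arguments
        factories[name] = functools.partial(resolve_target(spec["_target_"]), **extra)
    return factories


kernels = load_kernels_sketch(
    {"Number": {"_target_": "fractions.Fraction"}},
    defaults={"Number": "decimal.Decimal", "Mapping": "collections.OrderedDict"},
)
print(kernels["Number"](1, 4))  # 1/4, because the supplied kernel overrides the default
```

Returning factories rather than instances lets each layer create kernels with its own sizes, e.g. a Linear factory called with in_features and out_features.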