Subpackages

anemoi.inference.checkpoint.get_multi_dataset_metadata(metadata: dict, supporting_arrays: dict, base_class=<class 'anemoi.inference.metadata.Metadata'>) dict[str, SingleDatasetMetadata | MultiDatasetMetadata]

Metadata objects for all datasets in the raw metadata, as a mapping from dataset name to Metadata object. For legacy checkpoints, the dataset name defaults to ‘data’.

class anemoi.inference.checkpoint.Checkpoint(source: str | ~anemoi.inference.metadata.Metadata | dict[~typing.Literal['huggingface'], str | dict], *, metadata_base: type[~anemoi.inference.metadata.Metadata] = <class 'anemoi.inference.metadata.Metadata'>, patch_metadata: dict[str, ~typing.Any] | None = None)

Bases: object

Represents an inference checkpoint.

property path: str

Get the path to the checkpoint.

property multi_dataset: bool

Check if the checkpoint is a multi-dataset checkpoint.

property multi_dataset_metadata: dict[str, SingleDatasetMetadata | MultiDatasetMetadata]

Metadata for all datasets in the checkpoint, as a mapping from dataset name to Metadata object. For legacy checkpoints, the dataset name defaults to ‘data’.

property timestep: Any

Get the timestep.

property lagged: list[timedelta]

Return the list of steps for the multi_step_input fields.

property multi_step_input: int

Get the multi-step input.

property multi_step_output: int

Get the multi-step output.

property input_explicit_times: Any

Get the input explicit times from metadata.

property target_explicit_times: Any

Get the target explicit times.

property data_frequency: Any

Get the data frequency.

property precision: Any

Get the precision.

property sources: list[SourceCheckpoint]

Get the sources.

report_error() None

Report an error.

validate_environment(*, all_packages: bool = False, on_difference: Literal['warn', 'error', 'ignore', 'return'] = 'warn', exempt_packages: list[str] | None = None) bool | str

Validate the environment of the checkpoint against the current environment.

Parameters:
  • all_packages (bool, optional) – Check all packages in the environment (True) or just anemoi’s (False), by default False.

  • on_difference (Literal['warn', 'error', 'ignore', 'return'], optional) – What to do on difference, by default “warn”

  • exempt_packages (list[str], optional) – List of packages to exempt from the check, by default EXEMPT_PACKAGES

Returns:

If on_difference is not ‘return’, a boolean: True if the environment is valid, False otherwise. If on_difference is ‘return’, the formatted text of the differences.

Return type:

Union[bool, str]
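The Returns entry above packs two contracts into one sentence. As a hedged illustration only, the sketch below shows one way the four documented on_difference modes could be dispatched once the package differences are known. The function handle_differences and its list-of-strings input are hypothetical and not part of anemoi-inference; the real method compares the checkpoint's environment with the current one, which this sketch does not model.

```python
from __future__ import annotations

import warnings
from typing import Literal


def handle_differences(
    differences: list[str],
    on_difference: Literal["warn", "error", "ignore", "return"] = "warn",
) -> bool | str:
    """Hypothetical dispatcher mirroring the documented on_difference modes."""
    if on_difference not in ("warn", "error", "ignore", "return"):
        raise ValueError(f"Invalid on_difference: {on_difference!r}")
    text = "\n".join(differences)
    if on_difference == "return":
        # 'return' mode: hand back the formatted differences instead of a boolean
        return text
    if not differences:
        # No differences: the environment is valid
        return True
    if on_difference == "warn":
        warnings.warn(f"Environment differs from checkpoint:\n{text}")
    elif on_difference == "error":
        raise RuntimeError(f"Environment differs from checkpoint:\n{text}")
    # 'warn' and 'ignore' both report the environment as invalid
    return False
```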

provenance_training() Any

Get the provenance of the training.

Returns:

The provenance of the training.

Return type:

Any

update_metadata_from_zarr() tuple[dict[str, Any], dict[str, Any]]

Get new metadata and supporting array dictionaries from the original training dataset. This is useful if the training dataset metadata has been patched or updated. Updates the internal _raw_metadata and returns a new (metadata, supporting_arrays) tuple.

class anemoi.inference.checkpoint.SourceCheckpoint(owner: Checkpoint, metadata: Any)

Bases: Checkpoint

A checkpoint that represents a source.

property operational_config: dict[str, Any]

anemoi.inference.checks.check_data(title: str, data: FieldList, variables: list[str], dates: list[datetime], metadata: Metadata) None

Check if the data matches the expected number of fields based on variables and dates.

Parameters:
  • title (str) – The title for the data check.

  • data (FieldList) – The data to be checked.

  • variables (List[str]) – The list of variable names.

  • dates (List[datetime.datetime]) – The list of dates.

  • metadata (Metadata) – The metadata object associated with the dataset that produced the data.

Raises:

ValueError – If the data does not match the expected number of fields.
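A minimal sketch of the kind of count check described above, assuming one field is expected per (variable, date) pair. check_field_count and expected_field_count are hypothetical helpers that take a plain field count instead of an earthkit-data FieldList, and they ignore the metadata argument.

```python
from __future__ import annotations

from datetime import datetime


def expected_field_count(variables: list[str], dates: list[datetime]) -> int:
    # Assumption: one field per (variable, date) pair
    return len(variables) * len(dates)


def check_field_count(title: str, n_fields: int, variables: list[str], dates: list[datetime]) -> None:
    expected = expected_field_count(variables, dates)
    if n_fields != expected:
        raise ValueError(
            f"{title}: expected {expected} fields "
            f"({len(variables)} variables x {len(dates)} dates), got {n_fields}"
        )


dates = [datetime(2024, 1, 1, 0), datetime(2024, 1, 1, 6)]
check_field_count("initial conditions", 4, ["2t", "msl"], dates)  # passes silently
```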

class anemoi.inference.context.Context

Bases: ABC

Represents the context of the inference.

allow_nans = None
use_grib_paramid = False
verbosity = 0
development_hacks: dict[str, Any] = {}
reference_date = None
time_step = None
lead_time = None
output_frequency: int | None = None
write_initial_state: bool = True
abstract property checkpoint: Checkpoint

Returns the checkpoint used for the inference.

class anemoi.inference.decorators.main_argument(name: str)

Bases: object

Decorator to set the main argument of a class. Only for classes with a ‘context’ argument.

For example:

```
@main_argument("path")
class GribOutput:
    def __init__(self, context, encoding=None, path=None, archive_requests=None):
        ...

output = GribOutput(context, "out.grib")
```

So in the config we can have:

```
output:
  grib: out.grib
```

meaning the same as:

```
output:
  grib:
    path: out.grib
```
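The config equivalence above can be illustrated with a small stand-in. The sketch below is a hypothetical re-implementation, not the library's decorator: main_argument_sketch routes a single positional argument after context to the named keyword, and from_config (also hypothetical) shows how a scalar config value and a mapping with the same key produce the same object.

```python
from __future__ import annotations

import functools
from typing import Any


def main_argument_sketch(name: str):
    """Hypothetical stand-in for main_argument: map one positional value to `name`."""

    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def __init__(self, context, *args: Any, **kwargs: Any) -> None:
            if len(args) > 1:
                raise TypeError(f"{cls.__name__} takes at most one positional argument after context")
            if args:
                # The single positional value becomes the main keyword argument
                kwargs[name] = args[0]
            original_init(self, context, **kwargs)

        cls.__init__ = __init__
        return cls

    return decorator


@main_argument_sketch("path")
class GribOutputSketch:
    def __init__(self, context, encoding=None, path=None, archive_requests=None):
        self.context = context
        self.path = path


def from_config(cls, context, value):
    """Hypothetical config loader: scalar -> main argument, mapping -> keywords."""
    if isinstance(value, dict):
        return cls(context, **value)
    return cls(context, value)
```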

class anemoi.inference.decorators.ensure_path(arg: str, is_dir: bool = False, create: bool = True, must_exist: bool = False, unique: bool = True)

Bases: object

Decorator to ensure a path argument is a Path object and optionally exists.

If is_dir is True, the path itself is treated as a directory; otherwise the path is treated as a file and its parent directory is used. If must_exist is True, the directory must exist. If create is True, the directory is created if it does not exist. If unique is True, the same path cannot be reused across multiple decorated classes.

For example:

```
@ensure_path("dir", create=True)
class GribOutput:
    def __init__(self, context, dir=None, archive_requests=None):
        ...
```
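A hedged sketch of the path handling described above, using only pathlib. ensure_path_sketch is a hypothetical stand-in, not the library's decorator: it only inspects the decorated argument when it is passed as a keyword, and it does not implement the unique check.

```python
from __future__ import annotations

import functools
import tempfile
from pathlib import Path


def ensure_path_sketch(arg: str, is_dir: bool = False, create: bool = True, must_exist: bool = False):
    """Hypothetical stand-in for ensure_path (keyword arguments only, no `unique`)."""

    def decorator(cls):
        original_init = cls.__init__

        @functools.wraps(original_init)
        def __init__(self, context, *args, **kwargs):
            value = kwargs.get(arg)
            if value is not None:
                path = Path(value)
                # Directory to check or create: the path itself, or its parent for files
                directory = path if is_dir else path.parent
                if must_exist and not directory.exists():
                    raise FileNotFoundError(f"{directory} does not exist")
                if create:
                    directory.mkdir(parents=True, exist_ok=True)
                kwargs[arg] = path
            original_init(self, context, *args, **kwargs)

        cls.__init__ = __init__
        return cls

    return decorator


@ensure_path_sketch("dir", is_dir=True)
class PlotOutputSketch:
    def __init__(self, context, dir=None):
        self.dir = dir


with tempfile.TemporaryDirectory() as tmp:
    out = PlotOutputSketch(None, dir=f"{tmp}/plots/run1")
    print(out.dir.is_dir())  # True
```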

class anemoi.inference.decorators.ensure_dir(arg: str, create: bool = True, must_exist: bool = False, unique: bool = True)

Bases: ensure_path

Decorator to ensure a directory path argument is a Path object and optionally exists.

If must_exist is True, the directory must exist. If create is True, the directory will be created if it doesn’t exist. If ‘unique’ is True, the same path cannot be reused between multiple decorated classes.

For example:

```
@ensure_dir("dir", create=True)
class PlotOutput:
    def __init__(self, context, dir=None, …):
        ...
```

class anemoi.inference.decorators.format_dataset_name(arg: str)

Bases: object

Decorator to format a string argument with the dataset name. Substitutes {dataset} or {dataset_name} in the argument with the dataset name. Can only be used for classes that take metadata. For example:

```
output:
  grib: output-{dataset}.grib
```
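The substitution itself can be sketched as a plain string replacement. format_dataset_name_sketch is a hypothetical helper, not the decorator; it uses str.replace rather than str.format so that any other braces in the value are left alone, which is an assumption about the desired behaviour.

```python
def format_dataset_name_sketch(value: str, dataset_name: str) -> str:
    # Replace both documented placeholders; "{dataset}" cannot match inside "{dataset_name}"
    return value.replace("{dataset}", dataset_name).replace("{dataset_name}", dataset_name)


print(format_dataset_name_sketch("output-{dataset}.grib", "era5"))  # output-era5.grib
```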

anemoi.inference.device.get_available_device() torch.device

Get the available device for PyTorch.

Returns:

The available device, either ‘cuda’, ‘mps’, or ‘cpu’.

Return type:

torch.device

class anemoi.inference.forcings.Forcings(context: TensorHandler)

Bases: ABC

Represents a forcings provider for the model.

mask: ndarray[tuple[Any, ...], dtype[Any]]
variables: list[str]
abstractmethod load_forcings_array(dates: list[str | datetime | int], current_state: dict[str, Any]) ndarray[tuple[Any, ...], dtype[Any]]

Load the forcings for the given dates.

Parameters:
  • dates (List[Date]) – The list of dates for which to load the forcings.

  • current_state (State) – The current state of the model.

Returns:

The loaded forcings as a numpy array.

Return type:

FloatArray

class anemoi.inference.forcings.ComputedForcings(context: TensorHandler, variables: list[str], mask: Any)

Bases: Forcings

Compute forcings like cos_julian_day or insolation.

trace_name = 'computed'
load_forcings_array(dates: list[str | datetime | int], current_state: dict[str, Any]) ndarray[tuple[Any, ...], dtype[Any]]

Load the computed forcings for the given dates.

Parameters:
  • dates (List[Date]) – The list of dates for which to load the forcings.

  • current_state (State) – The current state of the model.

Returns:

The loaded forcings as a numpy array.

Return type:

FloatArray
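As an illustration of one computed forcing, the sketch below evaluates cos_julian_day (and a sine counterpart) for a single date. It assumes the common definition in which the fractional day of the year is mapped onto a 365.25-day cycle; the exact formula used by anemoi-inference may differ, and insolation is not shown. Under this definition the value depends only on the date, so a gridded forcings array would hold the same value at every point. The helper names are hypothetical.

```python
import math
from datetime import datetime


def julian_day(date: datetime) -> float:
    # Fractional number of days elapsed since 1 January of the same year
    start = datetime(date.year, 1, 1)
    return (date - start).total_seconds() / 86400.0


def cos_julian_day(date: datetime) -> float:
    # Map the day of year onto the unit circle so that 31 Dec and 1 Jan are close
    return math.cos(2.0 * math.pi * julian_day(date) / 365.25)


def sin_julian_day(date: datetime) -> float:
    return math.sin(2.0 * math.pi * julian_day(date) / 365.25)
```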

class anemoi.inference.forcings.CoupledForcings(context: TensorHandler, input: Input, variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]])

Bases: Forcings

Retrieve forcings from the input.

property trace_name: str

Return the trace name of the input.

load_forcings_array(dates: list[str | datetime | int], current_state: dict[str, Any]) ndarray[tuple[Any, ...], dtype[Any]]

Load the forcings for the given dates.

Parameters:
  • dates (List[Any]) – The list of dates for which to load the forcings.

  • current_state (State) – The current state of the model.

Returns:

The loaded forcings as a numpy array.

Return type:

FloatArray

class anemoi.inference.forcings.ConstantForcings(context: TensorHandler, input: Input, variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]])

Bases: CoupledForcings

class anemoi.inference.forcings.BoundaryForcings(context: TensorHandler, input: DatasetInput, variables: list[str], variables_mask: ndarray[tuple[Any, ...], dtype[Any]])

Bases: Forcings

Retrieve boundary forcings from the input.

load_forcings_array(dates: list[str | datetime | int], current_state: dict[str, Any]) ndarray[tuple[Any, ...], dtype[Any]]

Load the boundary forcings for the given dates.

Parameters:
  • dates (List[Date]) – The list of dates for which to load the forcings.

  • current_state (State) – The current state of the model.

Returns:

The loaded forcings as a numpy array.

Return type:

FloatArray

class anemoi.inference.input.Input(context: Context, metadata: Metadata, *, variables: list[str] | None = None, pre_processors: list[str | dict[str, Any]] | None = None, purpose: str | None = None)

Bases: ABC

Abstract base class for input handling.

trace_name = '????'
property pre_processors: list[Processor]

Return pre-processors.

pre_process(x: Any) Any

Run pre-processors.

Parameters:

x (Any) – input to pre-process

Returns:

Pre-processed input

Return type:

Any

abstractmethod create_input_state(*, date: str | datetime | int | None, **kwargs) dict[str, Any]

Create the input state dictionary.

Parameters:
  • date (Optional[Date]) – The date for which to create the input state.

  • **kwargs (Any) – Additional keyword arguments.

Returns:

The input state dictionary.

Return type:

State

abstractmethod load_forcings_state(*, dates: list[str | datetime | int], current_state: dict[str, Any]) dict[str, Any]

Load forcings (constant and dynamic).

Parameters:
  • dates (List[Date]) – The list of dates for which to load the forcings.

  • current_state (State) – The current state of the model.

Returns:

The updated state with the loaded forcings.

Return type:

State

input_variables() list[str]

Return the list of input variables.

Returns:

The list of input variables.

Return type:

List[str]

patch_data_request(request: Any) Any

Patch the data request.

Uses both the context and input preprocessors.

Parameters:

request (Any) – The data request.

Returns:

The patched data request.

Return type:

Any

set_private_attributes(state: dict[str, Any], value: Any) None

Provide a way for a subclass to set private attributes in the state dictionary that may be needed by the output object.

Parameters:
  • state (State) – The state dictionary.

  • value (Any) – The value to set.

class anemoi.inference.lazy.LazyModule(module_name: str)

Bases: object

Defer loading of a module until attribute access.
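A minimal sketch of the deferred-import pattern, using importlib. LazyModuleSketch is hypothetical and is not the library's class; it relies on __getattr__ only being called for attributes that normal lookup does not find.

```python
from __future__ import annotations

import importlib
import types
from typing import Any


class LazyModuleSketch:
    """Hypothetical stand-in: import the module on first attribute access."""

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module: types.ModuleType | None = None

    def __getattr__(self, attr: str) -> Any:
        # Only called when normal lookup fails, so _module_name and _module
        # (set in __init__) never reach this method
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, attr)


lazy_json = LazyModuleSketch("json")  # no import happens here
print(lazy_json.dumps({"a": 1}))  # {"a": 1}
```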

anemoi.inference.legacy.warn(func: Callable[[...], Any]) Callable[[...], Any]

Decorator to issue a warning when using legacy functions.

Parameters:

func (function) – The legacy function to be wrapped.

Returns:

The wrapped function with a warning.

Return type:

function
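A hedged sketch of such a decorator using functools.wraps and warnings.warn. The name warn_legacy, the message text and the choice of DeprecationWarning are assumptions, not the library's actual behaviour.

```python
import functools
import warnings
from typing import Any, Callable


def warn_legacy(func: Callable[..., Any]) -> Callable[..., Any]:
    """Hypothetical legacy-warning decorator."""

    @functools.wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # stacklevel=2 attributes the warning to the caller, not to this wrapper
        warnings.warn(f"{func.__qualname__} is a legacy function", DeprecationWarning, stacklevel=2)
        return func(*args, **kwargs)

    return wrapper


@warn_legacy
def legacy_double(x: int) -> int:
    return 2 * x
```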

class anemoi.inference.legacy.LegacyMixin(*args, **kwargs)

Bases: MetadataProtocol

class anemoi.inference.metadata.Metadata(metadata: dict[str, Any], supporting_arrays: dict[str, ndarray[tuple[Any, ...], dtype[Any]]] = {})

Bases: LegacyMixin

Base Metadata class.

multi_dataset = False
dataset_name = 'data'
property target_explicit_times: Any

Return the target explicit times from the training configuration.

property input_explicit_times: Any

Return the input explicit times from the training configuration.

property data_frequency: Any

Get the data frequency.

print_indices(print=<bound method Logger.info of <Logger anemoi.inference.metadata (WARNING)>>) None

Print data and model indices for debugging purposes.

property lagged: list[timedelta]

Return the list of steps for the multi_step_input fields.
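Assuming the input steps are spaced one timestep apart and end at the reference date, the lagged offsets can be computed as below. lagged_offsets is a hypothetical helper; with multi_step_input = 2 and a 6-hour timestep it gives -6 h and 0 h.

```python
from __future__ import annotations

from datetime import timedelta


def lagged_offsets(multi_step_input: int, timestep: timedelta) -> list[timedelta]:
    """Hypothetical helper: offsets of the input steps, oldest first, ending at 0."""
    return [step * timestep for step in range(-multi_step_input + 1, 1)]
```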

property timestep: timedelta

Model time stepping timestep.

property precision: str | int

Return the precision of the model (bits per float).

property input_shape: tuple[int, int, int, int]
property output_shape: tuple[int, int, int, int]
property variable_to_input_tensor_index: MappingProxyType

Return the mapping between variable name and input tensor index.

property variable_to_output_tensor_index: MappingProxyType

Return the mapping between variable name and output tensor index.

property input_tensor_index_to_variable: MappingProxyType

Return the mapping between input tensor index and variable name.

property output_tensor_index_to_variable: MappingProxyType

Return the mapping between output tensor index and variable name.
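The four mapping properties come in inverse pairs. The sketch below builds one such pair for a hypothetical list of variables and wraps each in types.MappingProxyType, which gives a read-only view. The variable names are made up, and in a real checkpoint the input and output index sets need not contain the same variables.

```python
from types import MappingProxyType

# Hypothetical input variables, in input-tensor order
input_variables = ["10u", "10v", "2t", "msl"]

variable_to_input_tensor_index = MappingProxyType(
    {name: index for index, name in enumerate(input_variables)}
)
# The inverse mapping, built from the forward one
input_tensor_index_to_variable = MappingProxyType(
    {index: name for name, index in variable_to_input_tensor_index.items()}
)

print(variable_to_input_tensor_index["2t"])  # 2
print(input_tensor_index_to_variable[3])  # msl
```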

property number_of_grid_points: int

Return the number of grid points per field.

property number_of_input_features: int

Return the number of input features.

property model_computed_variables: tuple

The initial-condition variables that must be computed rather than retrieved.

property multi_step_input: int

Number of past steps needed for the initial conditions tensor.

property multi_step_output: int

Number of future steps predicted by a single forward pass of the model.

property prognostic_output_mask: ndarray[tuple[Any, ...], dtype[Any]]

Return the prognostic output mask.

property prognostic_input_mask: ndarray[tuple[Any, ...], dtype[Any]]

Return the prognostic input mask.

property computed_time_dependent_forcings: tuple[ndarray, list]

Return the indices and names of the computed forcings that are not constant in time.

Deprecated since version 0.6.4: This will be removed in 0.7.0. Use select_variables_and_masks instead.

property computed_constant_forcings: tuple[ndarray[tuple[Any, ...], dtype[Any]], list[str]]

Return the indices and names of the computed forcings that are constant in time.

Deprecated since version 0.6.4: This will be removed in 0.7.0. Use select_variables_and_masks instead.

has_supporting_array(name: str) bool

Check if the metadata has a supporting array with the given name.

Parameters:

name (str) – The name of the supporting array.

Returns:

True if the supporting array exists, False otherwise.

Return type:

bool

property variables: tuple

Return the variables as found in the training dataset.

property variables_metadata: dict[str, Any]

Return the variables and their metadata as found in the training dataset.

property diagnostic_variables: list

Variables that are marked as diagnostic.

Deprecated since version 0.6.4: This will be removed in 0.7.0. Use select_variables instead.

property prognostic_variables: list

Variables that are marked as prognostic.

Deprecated since version 0.6.4: This will be removed in 0.7.0. Use select_variables instead.

property index_to_variable: MappingProxyType

Return a mapping from index to variable name.

property typed_variables: dict[str, Variable]

Return the variables as strongly typed Variable objects.

property accumulations: list

Return the indices of the variables that are accumulations.

name_fields(fields: FieldList, namer: Callable[[...], str] | None = None) FieldList

Name fields using the provided namer.

Parameters:
  • fields (FieldList) – The fields to name.

  • namer (callable, optional) – The namer function, by default None.

Returns:

The named fields.

Return type:

FieldList

sort_by_name(fields: FieldList, *args: Any, namer: Callable[[...], Any] | None = None, **kwargs: Any) FieldList

Sort fields by name.

Parameters:
  • fields (ekd.FieldList) – The fields to sort.

  • args (Any) – Additional arguments.

  • namer (callable, optional) – The namer function, by default None.

  • kwargs (Any) – Additional keyword arguments.

Returns:

The sorted fields.

Return type:

ekd.FieldList

default_namer(*args: Any, **kwargs: Any) Callable[[...], str]

Return a callable that can be used to name earthkit-data fields.

Parameters:
  • args (Any) – Additional arguments.

  • kwargs (Any) – Additional keyword arguments.

Returns:

The namer function.

Return type:

Callable
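A hypothetical namer in the spirit of default_namer, working on plain metadata dictionaries rather than earthkit-data fields. It assumes the naming convention param for single-level fields and param_level for pressure-level fields (for example t_850); the library's default namer may apply further rules.

```python
from __future__ import annotations

from typing import Any, Callable


def make_namer() -> Callable[[dict[str, Any]], str]:
    """Hypothetical namer factory operating on plain metadata dicts."""

    def namer(meta: dict[str, Any]) -> str:
        # Pressure-level fields get the level appended, e.g. t_850
        if meta.get("levtype") == "pl":
            return f"{meta['param']}_{meta['levelist']}"
        return meta["param"]

    return namer


namer = make_namer()
print(namer({"param": "t", "levtype": "pl", "levelist": 850}))  # t_850
print(namer({"param": "2t", "levtype": "sfc"}))  # 2t
```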

property grid: str | None

Return the grid information.

property area: str | None

Return the area information.

select_variables(*, include: list[str] | None = None, exclude: list[str] | None = None, has_mars_requests: bool = False) list[str]

Select variables by category.

Parameters:
  • include (List[str]) – Categories to include.

  • exclude (List[str]) – Categories to exclude.

  • has_mars_requests (bool) – If True, only include variables that have MARS requests.

Returns:

The list of variables.

Return type:

List[str]
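The include/exclude filtering can be sketched over a hypothetical category table. This assumes a variable is kept when it carries every included category and none of the excluded ones; the category names, the table, and the helpers select_variables_sketch and variables_mask_sketch are illustrative only. The boolean mask mirrors the idea behind variables_mask, with one entry per variable in training order.

```python
from __future__ import annotations

import numpy as np

# Hypothetical category table, keyed by variable in training-dataset order
CATEGORIES: dict[str, set[str]] = {
    "2t": {"prognostic"},
    "tp": {"diagnostic", "accumulation"},
    "lsm": {"forcing", "constant"},
    "cos_julian_day": {"forcing", "computed"},
}
VARIABLES = list(CATEGORIES)


def select_variables_sketch(*, include: list[str] | None = None, exclude: list[str] | None = None) -> list[str]:
    selected = []
    for name in VARIABLES:
        categories = CATEGORIES[name]
        # Keep only variables carrying every included category...
        if include is not None and not set(include) <= categories:
            continue
        # ...and none of the excluded ones
        if exclude is not None and set(exclude) & categories:
            continue
        selected.append(name)
    return selected


def variables_mask_sketch(*, variables: list[str]) -> np.ndarray:
    # One boolean per variable, True where the variable was requested
    wanted = set(variables)
    return np.array([name in wanted for name in VARIABLES], dtype=bool)
```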

variables_mask(*, variables: list[str]) ndarray[tuple[Any, ...], dtype[Any]]
select_variables_and_masks(*, include: list[str] | None = None, exclude: list[str] | None = None) tuple[list[str], ndarray[tuple[Any, ...], dtype[Any]]]
mars_input_requests() Iterator[dict[str, Any]]

Generate MARS input requests.

Returns:

The MARS requests.

Return type:

Iterator[DataRequest]

mars_by_levtype(levtype: str) tuple[set, set]

Get MARS parameters and levels by levtype.

Parameters:

levtype (str) – The levtype to filter by.

Returns:

The parameters and levels.

Return type:

tuple

mars_requests(*, variables: list[str], dates: list[str | datetime | int], use_grib_paramid: bool = False, always_split_time: bool = False, patch_request: Callable[[dict[str, Any]], dict[str, Any]] | None = None, dont_fail_for_missing_paramid: bool = False, **kwargs: Any) list[dict[str, Any]]

Generate MARS requests for the given variables and dates.

Parameters:
  • variables (list[str]) – The list of variables.

  • dates (list[Date]) – The list of dates.

  • use_grib_paramid (bool, optional) – Whether to use GRIB paramid, by default False.

  • always_split_time (bool, optional) – Whether to always split time, by default False.

  • patch_request (Optional[Callable], optional) – A callable to patch the request, by default None.

  • dont_fail_for_missing_paramid (bool, optional) – Whether to not fail for missing param ids, by default False.

  • **kwargs (Any) – Additional keyword arguments.

Returns:

The list of MARS requests.

Return type:

List[DataRequest]

simple_mars_requests(*, variables: list[str]) Iterator[dict[str, Any]]

Generate MARS requests for the given variables.

Parameters:

variables (list) – The list of variables.

Returns:

The MARS requests.

Return type:

Iterator[DataRequest]

Raises:

ValueError – If no variables are requested or if a variable is not found in the metadata.

report_error() None

Report an error with provenance information.

validate_environment(*, all_packages: bool = False, on_difference: Literal['warn', 'error', 'ignore', 'return'] = 'warn', exempt_packages: list[str] | None = None) bool | str

Validate environment of the checkpoint against the current environment.

Parameters:
  • all_packages (bool, optional) – Check all packages in the environment (True) or just anemoi’s (False), by default False.

  • on_difference (Literal['warn', 'error', 'ignore', 'return'], optional) – What to do on difference, by default “warn”

  • exempt_packages (list[str], optional) – List of packages to exempt from the check, by default EXEMPT_PACKAGES

Returns:

If on_difference is not ‘return’, a boolean: True if the environment is valid, False otherwise. If on_difference is ‘return’, the formatted text of the differences.

Return type:

Union[bool, str]

Raises:
  • RuntimeError – If found difference and on_difference is ‘error’

  • ValueError – If on_difference is not ‘warn’ or ‘error’

open_dataset(*, use_original_paths: bool | None = None, from_dataloader: str | None = None) tuple[Any, Any]

Open the dataset.

Parameters:
  • use_original_paths (bool) – Whether to use the original paths.

  • from_dataloader (str, optional) – The dataloader to use, by default None.

Returns:

The opened dataset and its arguments.

Return type:

tuple

open_dataset_args_kwargs(*, use_original_paths: bool, from_dataloader: str | None = None) tuple[Any, Any]

Get the arguments and keyword arguments for opening the dataset.

Parameters:
  • use_original_paths (bool) – Whether to use the original paths.

  • from_dataloader (str, optional) – The dataloader to use, by default None.

Returns:

The arguments and keyword arguments.

Return type:

tuple

variable_categories() dict

Get the categories of variables.

Returns:

The categories of variables.

Return type:

dict

load_supporting_array(name: str) ndarray[tuple[Any, ...], dtype[Any]]

Load a supporting array by name.

Parameters:

name (str) – The name of the supporting array.

Returns:

The supporting array.

Return type:

FloatArray

Raises:

ValueError – If the supporting array is not found.

property supporting_arrays: dict[str, ndarray[tuple[Any, ...], dtype[Any]]]

Return the supporting arrays.

property latitudes: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the latitudes.

property longitudes: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the longitudes.

property grid_points_mask: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the grid points mask.

provenance_training() dict[str, Any]

Get the environment configuration recorded at training time.

Returns:

The environmental configuration.

Return type:

dict

sources(path: str) list

Get the sources from the metadata.

Parameters:

path (str) – The path to the sources.

Returns:

The list of sources.

Return type:

list

Raises:

ValueError – If not all paths were fixed.

print_variable_categories(print=<bound method Logger.info of <Logger anemoi.inference.metadata (WARNING)>>) None

Print the variable categories for debugging purposes.

patch(patch: dict) None

Patch the metadata with the given patch.

Parameters:

patch (dict) – The patch to apply.

class anemoi.inference.metadata.SingleDatasetMetadata(metadata: dict[str, Any], supporting_arrays: dict[str, ndarray[tuple[Any, ...], dtype[Any]]] = {})

Bases: Metadata

Legacy single-dataset metadata.

class anemoi.inference.metadata.MultiDatasetMetadata(metadata: dict[str, Any], supporting_arrays: dict[str, dict[str, ndarray[tuple[Any, ...], dtype[Any]]]] = {}, dataset_name='data')

Bases: Metadata

Map metadata for a multi-dataset checkpoint to a specific dataset name.

multi_dataset = True
property dataset_names: list

List of canonical dataset names.

property task: str
property timestep: timedelta

Model time stepping timestep.

property multi_step_input: int

Number of past steps needed for the initial conditions tensor.

property multi_step_output: int

Number of future steps predicted by a single forward pass of the model.

property input_explicit_times: Any

Explicit times of the input steps used for the temporal downscaler.

property target_explicit_times: Any

Explicit times of the target steps used for the temporal downscaler.

property output_shape: tuple[int, int, int, int, int] | tuple[int, int, int, int]
property variable_to_input_tensor_index: MappingProxyType

Return the mapping between variable name and input tensor index.

property variable_to_output_tensor_index: MappingProxyType

Return the mapping between variable name and output tensor index.

property input_tensor_index_to_variable: MappingProxyType

Return the mapping between input tensor index and variable name.

property output_tensor_index_to_variable: MappingProxyType

Return the mapping between output tensor index and variable name.

variable_categories() dict[str, set[str]]

Get the categories of variables.

Returns:

The categories of variables.

Return type:

dict

class anemoi.inference.metadata.MetadataFactory(metadata: dict[str, ~typing.Any], supporting_arrays: dict[str, ~typing.Any] = {}, dataset_name='data', base_class=<class 'anemoi.inference.metadata.Metadata'>)

Bases: object

class anemoi.inference.metadata.SourceMetadata(parent: Metadata, name: str, metadata: dict, supporting_arrays: dict = {})

Bases: Metadata

An object that holds the metadata of a source. It contains only the dataset and supporting_arrays parts of the metadata; everything else is forwarded to the parent metadata object.

property latitudes: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the latitudes.

property longitudes: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the longitudes.

property grid_points_mask: ndarray[tuple[Any, ...], dtype[Any]] | None

Return the grid points mask.

class anemoi.inference.output.Output(context: Context, metadata: Metadata, *, variables: list[str] | None = None, post_processors: list[str | dict[str, Any]] | None = None, output_frequency: int | None = None, write_initial_state: bool | None = None)

Bases: ABC

Abstract base class for output mechanisms.

skip_variable(variable: str) bool

Check if a variable should be skipped.

Parameters:

variable (str) – The variable to check.

Returns:

True if the variable should be skipped, False otherwise.

Return type:

bool

property post_processors: list[Processor]

Return post-processors.

post_process(state: dict[str, Any]) dict[str, Any]

Apply post processors to the state.

Parameters:

state (State) – The state.

Returns:

The processed state.

Return type:

State

write_initial_state(state: dict[str, Any]) None

Write the initial state.

Parameters:

state (State) – The initial state to write.

write_state(state: dict[str, Any]) None

Write the state.

Parameters:

state (State) – The state to write.

classmethod reduce(state: dict[str, Any]) dict[str, Any]

Create a new state that is a projection of the original state onto the last step of the multi-step dimension.

Parameters:

state (State) – The original state.

Returns:

The reduced state.

Return type:

State
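A minimal sketch of the reduction, assuming the state stores its arrays under a fields key with the multi-step dimension first. reduce_state is a hypothetical helper; 1-D fields, which have no step dimension, are passed through unchanged.

```python
from __future__ import annotations

from typing import Any

import numpy as np


def reduce_state(state: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical helper: keep only the last step of each multi-step field."""
    reduced = dict(state)  # shallow copy, the input state is not modified
    reduced["fields"] = {
        name: values[-1] if values.ndim > 1 else values
        for name, values in state["fields"].items()
    }
    return reduced


state = {"fields": {"2t": np.array([[280.0, 281.0], [282.0, 283.0]])}}
print(reduce_state(state)["fields"]["2t"])  # [282. 283.]
```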

open(state: dict[str, Any]) None

Open the output for writing.

Parameters:

state (State) – The initial state.

close() None

Close the output.

abstractmethod write_step(state: dict[str, Any]) None

Write a step of the state.

Parameters:

state (State) – The state to write.

property write_step_zero: bool

Determine whether to write the initial state.

property output_frequency: timedelta | None

Get the output frequency.

print_summary(depth: int = 0) None

Print a summary of the output configuration.

Parameters:

depth (int, optional) – The indentation depth for the summary, by default 0.

class anemoi.inference.output.ForwardOutput(context: Context, metadata: Metadata, output: Output | Any, variables: list[str] | None = None, post_processors: list[str | dict[str, Any]] | None = None, output_frequency: int | None = None, write_initial_state: bool | None = None)

Bases: Output

Subclass of Output that forwards calls to other outputs.

Subclass this class to implement the desired output_frequency behaviour, which should only apply to leaf outputs.
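The forwarding pattern can be sketched with two small classes. SketchOutput, SketchForwardOutput and KelvinToCelsius are hypothetical; the forwarder only rewrites each state through modify_state and hands it to the wrapped output, leaving any output_frequency filtering to the leaf.

```python
from __future__ import annotations

from typing import Any

State = dict[str, Any]


class SketchOutput:
    """Hypothetical leaf output that records every state it is asked to write."""

    def __init__(self) -> None:
        self.written: list[State] = []

    def write_step(self, state: State) -> None:
        self.written.append(state)


class SketchForwardOutput:
    """Hypothetical forwarder: modify the state, then pass it to the wrapped output."""

    def __init__(self, output: SketchOutput) -> None:
        self.output = output

    def modify_state(self, state: State) -> State:
        return state  # identity by default; subclasses override

    def write_step(self, state: State) -> None:
        self.output.write_step(self.modify_state(state))


class KelvinToCelsius(SketchForwardOutput):
    """Hypothetical subclass converting a '2t' field from kelvin to degrees Celsius."""

    def modify_state(self, state: State) -> State:
        fields = dict(state["fields"])
        fields["2t"] = fields["2t"] - 273.15
        return {**state, "fields": fields}
```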

property output_frequency: timedelta | None

Get the output frequency.

modify_state(state: dict[str, Any]) dict[str, Any]

Modify the state before writing.

Parameters:

state (State) – The state to modify.

Returns:

The modified state.

Return type:

State

open(state) None

Open the output for writing.

Parameters:

state (State) – The initial state.

close() None

Close the output.

write_initial_state(state: dict[str, Any]) None

Write the initial step of the state.

Parameters:

state (State) – The state dictionary.

write_step(state: dict[str, Any]) None

Write a step of the state.

Parameters:

state (State) – The state to write.

print_summary(depth: int = 0) None

Print a summary of the output.

Parameters:

depth (int, optional) – The depth of the summary, by default 0.

List of precisions supported by the inference runner.

class anemoi.inference.precisions.LazyDict

Bases: object

A dictionary that loads its values lazily, so that torch is not imported at the top level, which can be slow.
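A hedged sketch of the lazy-value idea: each value is stored as a zero-argument factory that is only called, and then cached, on first access. LazyDictSketch and the string stand-in for a torch dtype are hypothetical; the real class presumably maps precision names to torch objects.

```python
from __future__ import annotations

from typing import Any, Callable


class LazyDictSketch:
    """Hypothetical dict whose values are computed on first access and cached."""

    def __init__(self, factories: dict[str, Callable[[], Any]]) -> None:
        self._factories = factories
        self._cache: dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        if key not in self._cache:
            # First access: run the factory (e.g. a slow import) and cache its result
            self._cache[key] = self._factories[key]()
        return self._cache[key]

    def get(self, key: str, default: Any = None) -> Any:
        return self[key] if key in self._factories else default

    def keys(self):
        return self._factories.keys()  # listing keys does not trigger any factory

    def values(self):
        return [self[key] for key in self._factories]

    def items(self):
        return [(key, self[key]) for key in self._factories]


calls: list[str] = []


def make_float32() -> str:
    calls.append("float32")
    return "float32"  # stand-in for a torch dtype


precisions = LazyDictSketch({"32": make_float32})
```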

get(key, default=None)
keys()
values()
items()

class anemoi.inference.processor.Processor(context: Context, metadata: Metadata)

Bases: ABC

Abstract base class for processors.

Parameters:
  • context (Context) – The context in which the processor operates.

  • metadata (Metadata) – Metadata corresponding to the dataset this processor is handling.

abstractmethod process(state: dict[str, Any]) dict[str, Any]

Process the given state.

Parameters:

state (State) – The state to be processed.

Returns:

The processed state.

Return type:

State

patch_data_request(data_request: dict[str, Any]) dict[str, Any]

Override this method if a processor needs to patch the data request (e.g. a MARS or CDS request).

Parameters:

data_request (DataRequest) – The data request to be patched.

Returns:

The patched data request.

Return type:

DataRequest
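A hypothetical processor and a minimal chain runner illustrating the contract above: process returns a new state, and patch_data_request defaults to returning the request unchanged. ProcessorSketch stands in for the real base class and omits the context and metadata arguments; ClipNegativePrecip and run_processors are invented examples.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Any

State = dict[str, Any]
DataRequest = dict[str, Any]


class ProcessorSketch(ABC):
    """Hypothetical stand-in for the base class (context and metadata omitted)."""

    @abstractmethod
    def process(self, state: State) -> State:
        """Return the processed state."""

    def patch_data_request(self, data_request: DataRequest) -> DataRequest:
        return data_request  # default: leave the request unchanged


class ClipNegativePrecip(ProcessorSketch):
    """Hypothetical processor: clip negative 'tp' values to zero."""

    def process(self, state: State) -> State:
        fields = dict(state["fields"])
        fields["tp"] = [max(value, 0.0) for value in fields["tp"]]
        return {**state, "fields": fields}


def run_processors(processors: list[ProcessorSketch], state: State) -> State:
    # Apply each processor in turn, feeding its output to the next
    for processor in processors:
        state = processor.process(state)
    return state
```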

anemoi.inference.profiler.ProfilingLabel(label: str, use_profiler: bool) Generator[None, None, None]

Add a label to a function so that the profiler can recognize it. The label is only applied if use_profiler is True.

Parameters:
  • label (str) – Name or description to identify the function.

  • use_profiler (bool) – Wrap the function with the label if True, otherwise just execute the function as it is.

Returns:

Yields to the caller.

Return type:

Generator[None, None, None]

anemoi.inference.profiler.ProfilingRunner(use_profiler: bool) Generator[None, None, None]

Perform time and memory usage profiles of the wrapped code.

Parameters:

use_profiler (bool) – Whether to profile the wrapped code (True) or not (False).

Returns:

Yields to the caller.

Return type:

Generator[None, None, None]
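The context-manager shape of ProfilingRunner can be sketched with the standard library. This swaps in time.perf_counter and tracemalloc, which only see Python-level allocations; the library's profiler may measure different things (for example GPU memory). profiling_runner_sketch and its report dictionary are hypothetical.

```python
from __future__ import annotations

import contextlib
import time
import tracemalloc
from typing import Any, Generator


@contextlib.contextmanager
def profiling_runner_sketch(
    use_profiler: bool, report: dict[str, Any] | None = None
) -> Generator[None, None, None]:
    """Hypothetical stand-in: wall time and peak traced Python memory of the block."""
    if not use_profiler:
        yield  # profiling disabled: run the wrapped code as it is
        return
    tracemalloc.start()
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        _, peak = tracemalloc.get_traced_memory()
        tracemalloc.stop()
        if report is not None:
            report["seconds"] = elapsed
            report["peak_bytes"] = peak


report: dict[str, Any] = {}
with profiling_runner_sketch(True, report):
    data = [0] * 100_000  # some work to profile
```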

class anemoi.inference.protocol.MetadataProtocol(*args, **kwargs)

Bases: Protocol

Protocol for metadata objects. This will keep mypy happy.

grid: str
variables: dict[str, Any]

anemoi.inference.provenance.validate_environment(metadata: Metadata, *, all_packages: bool = False, on_difference: Literal['warn', 'error', 'ignore'] = 'warn', exempt_packages: list[str] | None = None) bool
anemoi.inference.provenance.validate_environment(metadata: Metadata, *, all_packages: bool = False, on_difference: Literal['return'] = 'return', exempt_packages: list[str] | None = None) str

Validate environment of the checkpoint against the current environment.

Parameters:
  • metadata (Metadata) – Metadata object of the checkpoint, to validate against

  • all_packages (bool, optional) – Check all packages in environment or just anemoi’s, by default False

  • on_difference (Literal['warn', 'error', 'ignore', 'return'], optional) – What to do on difference, by default “warn”

  • exempt_packages (List[str], optional) – List of packages to exempt from the check, by default EXEMPT_PACKAGES

Returns:

If on_difference is not ‘return’, a boolean: True if the environment is valid, False otherwise. If on_difference is ‘return’, the formatted text of the differences.

Return type:

Union[bool, str]

Raises:
  • RuntimeError – If found difference and on_difference is ‘error’

  • ValueError – If on_difference is not ‘warn’ or ‘error’

class anemoi.inference.runner.RunnerClasses(*, tensor_handler: type[TensorHandler] = <class 'anemoi.inference.tensors.TensorHandler'>, checkpoint: type[Checkpoint] = <class 'anemoi.inference.checkpoint.Checkpoint'>, metadata: type[Metadata] = <class 'anemoi.inference.metadata.Metadata'>)

Bases: BaseModel

Configurable class types used by the Runner. Child runners can override these with different classes.

model_config = {'arbitrary_types_allowed': True}

Configuration for the model, should be a dictionary conforming to pydantic's ConfigDict (pydantic.config.ConfigDict).

tensor_handler: type[TensorHandler]
checkpoint: type[Checkpoint]
metadata: type[Metadata]

class anemoi.inference.runner.Runner(config: RunConfiguration, *, classes: RunnerClasses | None = None)

Bases: Context

A runner is responsible for running a model. This class provides the default forecaster implementation with rollout.

pre_processors: dict[str, list[Processor]]
post_processors: dict[str, list[Processor]]
tensor_handlers: dict[str, TensorHandler]
prognostics_inputs: dict[str, Input]
constant_forcings_inputs: dict[str, Input]
dynamic_forcings_inputs: dict[str, Input]
boundary_forcings_inputs: dict[str, Input]
outputs: dict[str, Output]
property checkpoint: Checkpoint

Returns:

The checkpoint object.

Return type:

Checkpoint

property device: device
run(*, input_states: dict[str, dict[str, Any]], lead_time: str | int | timedelta, return_numpy: bool = True) Generator[dict[str, dict[str, Any]], None, None]

Run the model.

Parameters:
  • input_states (dict[str, State]) – The input states for each dataset.

  • lead_time (Union[str, int, datetime.timedelta]) – The lead time.

  • return_numpy (bool, optional) – Whether to return the output state fields as numpy arrays, by default True. If False, torch tensors are returned instead.

Returns:

The forecasted states.

Return type:

Generator[dict[str, State], None, None]

initial_constant_forcings_providers(constant_forcings_providers: list[Forcings]) list[Forcings]

Modify the constant forcings providers for the first step.

initial_dynamic_forcings_providers(dynamic_forcings_providers: list[Forcings]) list[Forcings]

Modify the dynamic forcings providers for the initial step of the inference process.

This method provides a hook to adjust the list of dynamic forcings before the first inference step is executed. By default, it returns the inputs unchanged, but subclasses can override this method to implement custom preprocessing or initialization logic.

prepare_output_state(output: Generator[dict[str, dict[str, Any]], None, None], return_numpy: bool) Generator[dict[str, dict[str, Any]], None, None]

Prepare the output state.

Parameters:
  • output (Generator[dict[str, State], None, None]) – Output state generator. Expects a dictionary of states keyed by dataset name. Expects fields in each state to be torch tensors with shape (values, variables).

  • return_numpy (bool) – Whether to return the output state fields as numpy arrays.

Yields:

Generator[dict[str, State], None, None] – The prepared output state.

property autocast: dtype | str

The autocast precision.

property model: Module

Returns:

The loaded model.

Return type:

torch.nn.Module

predict_step(model: Module, input_tensors_torch: dict[str, Tensor], **kwargs: Any) dict[str, Tensor]

Predict the next step.

Parameters:
  • model (torch.nn.Module) – The model.

  • input_tensors_torch (dict[str, torch.Tensor]) – The input tensors for each dataset.

  • **kwargs (Any) – Additional keyword arguments that will be passed to the model’s predict_step method.

Returns:

The predicted step for each dataset.

Return type:

dict[str, torch.Tensor]

forecast_stepper(start_date: datetime, lead_time: timedelta) Generator[tuple[timedelta, list[datetime], list[datetime], bool], None, None]

Generate step and date variables for the forecast autoregressive loop.

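Parameters:
  • start_date (datetime.datetime) – The start date of the forecast.

  • lead_time (datetime.timedelta) – The lead time of the forecast.

Yields:

  • step (datetime.timedelta) – Time elapsed since the beginning of the forecast

  • valid_date (list[datetime.datetime]) – Valid dates of the forecast step

  • next_date (list[datetime.datetime]) – Dates used to prepare the next input tensor

  • is_last_step (bool) – True if this is the last step of the forecast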
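A minimal sketch of this stepping logic follows. It assumes a fixed timestep, a model that predicts one step at a time, and a lead time that is a whole multiple of the timestep. forecast_stepper_sketch and its timestep argument are hypothetical: in the runner the timestep comes from the checkpoint, and multi-step outputs may yield several valid dates per step. Setting next_date equal to valid_date is likewise an assumption for the single-output case.

```python
import datetime
from collections.abc import Iterator

Step = tuple[datetime.timedelta, list[datetime.datetime], list[datetime.datetime], bool]


def forecast_stepper_sketch(
    start_date: datetime.datetime,
    lead_time: datetime.timedelta,
    timestep: datetime.timedelta,
) -> Iterator[Step]:
    """Hypothetical single-output stepper; not the Runner implementation."""
    n_steps = lead_time // timestep  # number of autoregressive steps
    for s in range(1, n_steps + 1):
        step = timestep * s  # time elapsed since the start of the forecast
        valid_date = [start_date + step]
        # Assumption: the freshly predicted time becomes the newest input slot,
        # so forcings for the next input tensor are needed at the same date.
        next_date = list(valid_date)
        yield step, valid_date, next_date, s == n_steps
```

For a 24-hour lead time and a 6-hour timestep this yields four steps, the last of which is flagged with is_last_step set to True.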

forecast(lead_time: str, input_tensors_numpy: dict[str, ndarray[tuple[Any, ...], dtype[Any]]], input_states: dict[str, dict[str, Any]]) Generator[dict[str, dict[str, Any]], None, None]

Forecast the future states.

Parameters:
  • lead_time (str) – The lead time.

  • input_tensors_numpy (dict[str, FloatArray]) – The input tensors for each dataset, as numpy arrays with shape (multi_step_input, variables, values).

  • input_states (dict[str, State]) – The input states for each dataset.

Returns:

The forecasted states for each dataset.

Return type:

Generator[dict[str, State], None, None]

patch_data_request(request: dict, dataset_name: str) dict
input_state_hook(input_state: dict[str, Any]) None

Hook used by coupled runners to send the input state.

output_state_hook(state: dict[str, Any]) None

Hook used by coupled runners to send the output state.

complete_forecast_hook() None

Hook called at the end of the forecast.

has_split_input() bool
execute() None

Execute the runner.

create_output(dataset_name: str, metadata: Metadata) Output
create_input(input_type: Literal['prognostics', 'constant_forcings', 'dynamic_forcings', 'boundary_forcings'], dataset_name: str, metadata: Metadata) Input
create_pre_processors(dataset_name: str, metadata: Metadata) list[Processor]
create_post_processors(dataset_name: str, metadata: Metadata) list[Processor]
anemoi.inference.state.check_state(state: dict[str, Any], title: str = '<state>') None

Check the state for consistency.

Parameters:
  • state (dict) – The state to check.

  • title (str) – The title of the state (for logging).

Raises:

ValueError – If the state is not consistent.
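To illustrate what a consistency check might cover, the sketch below assumes a state laid out as a dictionary with "latitudes", "longitudes" and a "fields" mapping of variable name to array, with grid points on the last axis. check_state_sketch is hypothetical, and the real check_state may test more or different properties.

```python
from typing import Any

import numpy as np


def check_state_sketch(state: dict[str, Any], title: str = "<state>") -> None:
    """Hypothetical consistency check; the real check_state may test more."""
    for key in ("latitudes", "longitudes", "fields"):
        if key not in state:
            raise ValueError(f"{title}: missing {key!r}")

    latitudes = np.asarray(state["latitudes"])
    longitudes = np.asarray(state["longitudes"])
    if latitudes.shape != longitudes.shape:
        raise ValueError(f"{title}: latitudes {latitudes.shape} != longitudes {longitudes.shape}")

    n_points = latitudes.shape[-1]
    for name, values in state["fields"].items():
        # Assumption: the last axis of every field is the grid-point axis.
        if np.shape(values)[-1] != n_points:
            raise ValueError(
                f"{title}: field {name!r} has {np.shape(values)[-1]} points, expected {n_points}"
            )
```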

anemoi.inference.state.combine_states(*states: dict[str, Any]) dict[str, Any]

Combine multiple states into one.

Parameters:

*states (State) – The states to combine, passed as separate positional arguments.

Returns:

The combined state.

Return type:

dict
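One plausible way to combine states is sketched below, assuming each state keeps its per-variable arrays under a "fields" key. combine_states_sketch is hypothetical. It takes the union of the fields, keeps every other entry from the first state that defines it, and treats a field that appears twice as an error. Whether the real function rejects duplicates or checks that dates and grids agree is not stated here.

```python
from typing import Any


def combine_states_sketch(*states: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical merge: union of fields, other entries from the first state."""
    combined: dict[str, Any] = {}
    fields: dict[str, Any] = {}
    for state in states:
        for key, value in state.items():
            if key != "fields":
                combined.setdefault(key, value)  # first occurrence wins
                continue
            for name, data in value.items():
                if name in fields:
                    raise ValueError(f"Field {name!r} appears in more than one state")
                fields[name] = data
    combined["fields"] = fields
    return combined
```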

anemoi.inference.state.reduce_state(state: dict[str, Any]) dict[str, Any]

Create a new state which is a projection of the original state on the last step in the multi-steps dimension.

Parameters:

state (State) – The original state.

Returns:

The reduced state.

Return type:

State
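The projection onto the last step can be sketched with NumPy, assuming multi-step fields are stored under "fields" with shape (steps, values) and that one-dimensional fields already hold a single step. reduce_state_sketch is hypothetical and makes a shallow copy, so the original state is left unchanged.

```python
from typing import Any

import numpy as np


def reduce_state_sketch(state: dict[str, Any]) -> dict[str, Any]:
    """Hypothetical: keep only the last step of each multi-step field."""
    reduced = dict(state)  # shallow copy; the original state is left untouched
    reduced["fields"] = {
        # Assumption: multi-step fields have shape (steps, values);
        # 1-D fields are already a single step and are kept as-is.
        name: values[-1] if np.ndim(values) > 1 else values
        for name, values in state["fields"].items()
    }
    return reduced
```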

class anemoi.inference.task.Task(name: str)

Bases: ABC

Abstract base class for tasks.

Parameters:

name (str) – The name of the task.

class anemoi.inference.tensors.Kind(**attributes: Any)

Bases: object

Used for debugging purposes.

class anemoi.inference.tensors.TensorHandler(context: Runner, metadata: Metadata, constant_forcings_input: Input, dynamic_forcings_input: Input, boundary_forcings_input: Input, trace_path: str | None = None)

Bases: object

The TensorHandler is responsible for creating the input tensor for one dataset. It also handles loading the forcings and copying prognostic variables from the output tensor to the input tensor during rollout. A handler should be created per dataset. The metadata and inputs provided to the handler are specific to that dataset.

property dataset_name: str

Name of the dataset associated with the tensor handler.

prepare_input_tensor(input_state: dict[str, ~typing.Any], dtype: type[~typing.Any] | ~numpy.dtype[~typing.Any] | ~numpy._typing._dtype_like._HasDType[~numpy.dtype[~typing.Any]] | ~numpy._typing._dtype_like._HasNumPyDType[~numpy.dtype[~typing.Any]] | tuple[~typing.Any, ~typing.Any] | list[~typing.Any] | ~numpy._typing._dtype_like._DTypeDict | str = <class 'numpy.float32'>) ndarray[tuple[Any, ...], dtype[Any]]

Prepare the input tensor from the input state.
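The (multi_step_input, variables, values) layout documented for Runner.forecast suggests how a handler packs a state into an array. The sketch below is a hypothetical stand-in, not the TensorHandler implementation. It takes the variable order as an explicit argument (in practice it would come from the checkpoint metadata), assumes fields live under a "fields" key, and broadcasts single-step fields such as constant forcings to every input step.

```python
from typing import Any

import numpy as np
import numpy.typing as npt


def prepare_input_tensor_sketch(
    input_state: dict[str, Any],
    variables: list[str],
    multi_step_input: int,
    dtype: npt.DTypeLike = np.float32,
) -> np.ndarray:
    """Hypothetical: stack state fields into (multi_step_input, variables, values)."""
    fields = input_state["fields"]
    n_values = np.shape(fields[variables[0]])[-1]
    tensor = np.empty((multi_step_input, len(variables), n_values), dtype=dtype)
    for i, name in enumerate(variables):
        data = np.asarray(fields[name], dtype=dtype).reshape(-1, n_values)
        # A single time step (e.g. a constant forcing) is broadcast to every
        # input step; otherwise the most recent multi_step_input steps are used.
        tensor[:, i, :] = data[-multi_step_input:]
    return tensor
```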

validate_input_state(input_state: dict[str, Any]) dict[str, Any]

Check that the input state has all expected entries and shapes, and check it for NaNs.

add_initial_forcings_to_input_state(input_state: dict[str, Any]) None

Add initial forcings to the input state.

Parameters:

input_state (State) – The input state.

create_constant_forcings_providers() list[Forcings]
create_dynamic_forcings_providers() list[Forcings]
create_boundary_forcings_providers() list[BoundaryForcings]
copy_prognostic_fields_to_input_tensor(input_tensor_torch: Tensor, y_pred: Tensor, check: ndarray[tuple[Any, ...], dtype[Any]]) Tensor
add_dynamic_forcings_to_input_tensor(input_tensor_torch: Tensor, state: dict[str, Any], dates: list[datetime], check: ndarray[tuple[Any, ...], dtype[Any]]) Tensor
add_boundary_forcings_to_input_tensor(input_tensor_torch: Tensor, state: dict[str, Any], dates: list[datetime], check: ndarray[tuple[Any, ...], dtype[Any]]) Tensor
create_constant_computed_forcings(variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]]) list[Forcings]
create_dynamic_computed_forcings(variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]]) list[Forcings]
create_constant_coupled_forcings(variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]]) list[Forcings]
create_dynamic_coupled_forcings(variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]]) list[Forcings]
create_boundary_forcings(variables: list[str], mask: ndarray[tuple[Any, ...], dtype[Any]]) list[Forcings]
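The rollout bookkeeping behind copy_prognostic_fields_to_input_tensor can be illustrated in plain NumPy. This is a sketch under stated assumptions rather than the library code. It assumes an input of shape (multi_step_input, variables, values) without a batch dimension, a prediction of shape (output variables, values), and hypothetical index lists mapping prognostic variables from output to input. It also assumes check is a per-variable boolean record of which entries of the newest step have been filled, which the add_*_forcings_to_input_tensor methods would then complete.

```python
import numpy as np


def copy_prognostics_sketch(
    input_tensor: np.ndarray,
    y_pred: np.ndarray,
    input_idx: list[int],
    output_idx: list[int],
) -> tuple[np.ndarray, np.ndarray]:
    """Hypothetical rollout update on a (steps, variables, values) tensor."""
    # Shift the time window by one step: the oldest step wraps round to the
    # last slot, where it is about to be overwritten.
    rolled = np.roll(input_tensor, shift=-1, axis=0)
    # Write the predicted prognostic variables into the newest input step.
    rolled[-1, input_idx, :] = y_pred[output_idx, :]
    # Record which variables of the newest step are now up to date; the rest
    # (forcings) still hold stale values and must be filled in afterwards.
    check = np.zeros(input_tensor.shape[1], dtype=bool)
    check[input_idx] = True
    return rolled, check
```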
class anemoi.inference.trace.RolloutSource

Bases: object

Represents a source of data that is a rollout.

trace_name = 'rollout'
class anemoi.inference.trace.UnknownSource

Bases: object

Represents a source of data that is unknown.

trace_name = '?'
class anemoi.inference.trace.UnchangedSource

Bases: object

Represents a source of data that is unchanged.

trace_name = 'unchanged'
class anemoi.inference.trace.InputSource(input: Any)

Bases: object

Represents a source of data that is an input.

class anemoi.inference.transport.Coupling(source: Task, target: Task, variables: list[str])

Bases: object

Represents a coupling between a source and a target with specific variables.

class anemoi.inference.transport.CouplingSend(source: Task, target: Task, variables: list[str])

Bases: Coupling

Represents a coupling send operation.

apply(task: Task, transport: Transport, *, input_state: dict[str, Any], output_state: dict[str, Any], constants: dict[str, Any], tag: int) None

Apply the coupling send operation.

Parameters:
  • task (Task) – The task to apply the coupling to.

  • transport (Transport) – The transport instance to use.

  • input_state (State) – The input state dictionary.

  • output_state (State) – The output state dictionary.

  • constants (State) – The constants dictionary.

  • tag (int) – The tag for the operation.

class anemoi.inference.transport.CouplingRecv(source: Task, target: Task, variables: list[str])

Bases: Coupling

Represents a coupling receive operation.

apply(task: Task, transport: Transport, *, input_state: dict[str, Any], output_state: dict[str, Any], constants: dict[str, Any], tag: int) None

Apply the coupling receive operation.

Parameters:
  • task (Task) – The task to apply the coupling to.

  • transport (Transport) – The transport instance to use.

  • input_state (State) – The input state dictionary.

  • output_state (State) – The output state dictionary.

  • constants (State) – The constants dictionary.

  • tag (int) – The tag for the operation.

class anemoi.inference.transport.Transport(couplings: list[dict[str, list[str]]], tasks: dict[str, Task])

Bases: ABC

Abstract base class for transport mechanisms.

abstractmethod send(sender: Task, target: Task, state: dict[str, Any], tag: int) None

Send the state from the sender to the target.

Parameters:
  • sender (Task) – The sender of the state.

  • target (Task) – The target of the state.

  • state (State) – The state dictionary to send.

  • tag (int) – The tag for the operation.

abstractmethod receive(receiver: Task, source: Task, tag: int) dict[str, Any]

Receive, on behalf of the receiver, the state sent by the source.

Parameters:
  • receiver (Task) – The receiver of the state.

  • source (Task) – The source of the state.

  • tag (int) – The tag for the operation.

Returns:

The received state dictionary.

Return type:

State
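To make the send/receive contract concrete, here is a toy single-process implementation backed by FIFO queues. TransportSketch and InMemoryTransport are hypothetical stand-ins that do not subclass the real Transport, and tasks are represented by plain string names instead of Task objects. A real transport would also have to handle tasks that run concurrently, which this sketch ignores.

```python
from __future__ import annotations

from abc import ABC, abstractmethod
from collections import defaultdict, deque
from typing import Any

State = dict[str, Any]


class TransportSketch(ABC):
    """Stand-in for the send/receive contract (tasks are plain names here)."""

    @abstractmethod
    def send(self, sender: str, target: str, state: State, tag: int) -> None: ...

    @abstractmethod
    def receive(self, receiver: str, source: str, tag: int) -> State: ...


class InMemoryTransport(TransportSketch):
    """Hypothetical single-process transport backed by FIFO queues."""

    def __init__(self) -> None:
        # One queue per (source, target, tag) channel.
        self._queues: defaultdict[tuple[str, str, int], deque[State]] = defaultdict(deque)

    def send(self, sender: str, target: str, state: State, tag: int) -> None:
        self._queues[(sender, target, tag)].append(state)

    def receive(self, receiver: str, source: str, tag: int) -> State:
        # Raises IndexError if nothing has been sent on this channel yet.
        return self._queues[(source, receiver, tag)].popleft()
```

Keying each queue by (source, target, tag) keeps messages with different tags apart, mirroring the role of the tag parameter in send and receive.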

couplings(task: Task) list[Coupling]

Get the couplings for a given task.

Parameters:

task (Task) – The task to get the couplings for.

Returns:

The list of couplings for the task.

Return type:

list[Coupling]

send_state(sender: Task, target: Task, *, input_state: dict[str, Any], variables: list[str], constants: dict[str, Any], tag: int) None

Send the state from the sender to the target.

Parameters:
  • sender (Task) – The sender of the state.

  • target (Task) – The target of the state.

  • input_state (State) – The input state dictionary.

  • variables (list[str]) – The list of variables to send.

  • constants (State) – The constants dictionary.

  • tag (int) – The tag for the operation.

receive_state(receiver: Task, source: Task, *, output_state: dict[str, Any], variables: list[str], tag: int) None

Receive, on behalf of the receiver, the state sent by the source.

Parameters:
  • receiver (Task) – The receiver of the state.

  • source (Task) – The source of the state.

  • output_state (State) – The output state dictionary.

  • variables (list[str]) – The list of variables to receive.

  • tag (int) – The tag for the operation.

anemoi.inference.types.State

A dictionary that represents the state of a model.

alias of dict[str, Any]

anemoi.inference.types.DataRequest

A dictionary that represents a data request, such as a MARS, CDS or OpenData request.

alias of dict[str, Any]

anemoi.inference.types.Date

A date can be a string, a datetime object or an integer. It will always be converted to a datetime object.

alias of str | datetime | int
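The conversion of a Date to a datetime object could look roughly like the sketch below. to_datetime_sketch is hypothetical, and both encodings it accepts are assumptions: integers are read as YYYYMMDD and strings as ISO 8601. The library may accept other forms.

```python
from __future__ import annotations

import datetime


def to_datetime_sketch(date: str | datetime.datetime | int) -> datetime.datetime:
    """Hypothetical normalisation of a Date value to datetime.datetime."""
    if isinstance(date, datetime.datetime):
        return date
    if isinstance(date, int):
        # Assumption: integers are calendar dates written as YYYYMMDD.
        return datetime.datetime.strptime(str(date), "%Y%m%d")
    # Assumption: strings are ISO 8601, e.g. "2024-01-01T06:00".
    return datetime.datetime.fromisoformat(date)
```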

anemoi.inference.types.IntArray

A numpy array of integers.

alias of ndarray[tuple[Any, …], dtype[Any]]

anemoi.inference.types.FloatArray

A numpy array of floats.

alias of ndarray[tuple[Any, …], dtype[Any]]

anemoi.inference.types.BoolArray

A numpy array of booleans.

alias of ndarray[tuple[Any, …], dtype[Any]]

anemoi.inference.types.Shape

A tuple of integers representing the shape of an array.

alias of tuple[int, …]

anemoi.inference.types.ProcessorConfig

A string, or a dictionary with string keys, representing a pre- or post-processor configuration.

alias of str | dict[str, Any]
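Because a ProcessorConfig can be either form, a consumer typically normalises it before use. The sketch below assumes a bare string is a processor name with no options, and a dictionary has exactly one key (the processor name) whose value holds the options. split_processor_config and the name "my_processor" are hypothetical.

```python
from __future__ import annotations

from typing import Any


def split_processor_config(config: str | dict[str, Any]) -> tuple[str, dict[str, Any]]:
    """Hypothetical: turn a ProcessorConfig into (processor name, options)."""
    if isinstance(config, str):
        return config, {}  # bare name, no options
    if len(config) != 1:
        raise ValueError(f"Expected a single processor name, got {list(config)}")
    # Assumption: the single key is the processor name and its value the options.
    name, options = next(iter(config.items()))
    return name, dict(options or {})
```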

class anemoi.inference.variables.Variables(metadata: Metadata)

Bases: object

classmethod default_runner_input_variables_include_exclude()

Get include/exclude lists for default runner input variables.

default_input_variables()

Select default input variables from the checkpoint.

default_input_variables_and_mask()

Select default input variables and masks from the checkpoint.

classmethod retrieved_constant_forcings_variables_include_exclude()

Get include/exclude lists for retrieved constant forcings variables.

retrieved_constant_forcings_variables()

Select retrieved constant forcings variables from the checkpoint.

retrieved_constant_forcings_variables_and_mask()

Select retrieved constant forcings variables and masks from the checkpoint.

classmethod retrieved_prognostic_variables_include_exclude()

Get include/exclude lists for retrieved prognostic variables.

retrieved_prognostic_variables()

Select retrieved prognostic variables from the checkpoint.

retrieved_prognostic_variables_and_mask()

Select retrieved prognostic variables and masks from the checkpoint.

classmethod computed_constant_forcings_variables_include_exclude()

Get include/exclude lists for computed constant forcings variables.

computed_constant_forcings_variables()

Select computed constant forcings variables from the checkpoint.

computed_constant_forcings_variables_and_mask()

Select computed constant forcings variables and masks from the checkpoint.

classmethod retrieved_dynamic_forcings_variables_include_exclude()

Get include/exclude lists for retrieved dynamic forcings variables.

retrieved_dynamic_forcings_variables()

Select retrieved dynamic forcings variables from the checkpoint.

retrieved_dynamic_forcings_variables_and_mask()

Select retrieved dynamic forcings variables and masks from the checkpoint.

classmethod input_types()

Get all input types and their include/exclude lists.

classmethod input_type_to_include_exclude(input_type: str)

Get the include/exclude dict for a given input type.
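The *_include_exclude / *_and_mask pairs suggest a common selection pattern. A variable is kept if it carries every "include" attribute and none of the "exclude" ones, and its position is recorded in a mask. The sketch below is a hypothetical illustration of that pattern. The flag names ("forcing", "constant", "computed", …) are invented for the example and are not the checkpoint's actual metadata attributes.

```python
from __future__ import annotations


def select_variables_and_mask(
    variables: list[str],
    flags: dict[str, set[str]],
    include: list[str],
    exclude: list[str],
) -> tuple[list[str], list[int]]:
    """Hypothetical include/exclude filter returning names and their indices."""
    selected: list[str] = []
    mask: list[int] = []
    for index, name in enumerate(variables):
        tags = flags.get(name, set())
        # Keep a variable only if it has every included flag and no excluded one.
        if set(include) <= tags and not (set(exclude) & tags):
            selected.append(name)
            mask.append(index)
    return selected, mask
```

With include ["forcing", "constant"] and exclude ["computed"], only retrieved constant forcings such as a land-sea mask would be kept, while computed ones such as a cosine of latitude would be filtered out.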