Data

This module initialises the dataset (constructed using anemoi-datasets) and loads the data into the model. It also performs a series of checks, for example that the end date of the training dataset is before the start date of the validation dataset.
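The date check mentioned above can be pictured with a minimal sketch. The function name check_no_overlap and the use of plain datetime objects are assumptions made for illustration; the actual check in anemoi-training may be implemented differently.

```python
from datetime import datetime


def check_no_overlap(train_end: datetime, val_start: datetime) -> None:
    """Raise ValueError unless training ends strictly before validation starts.

    Hypothetical helper, not part of the anemoi-training API.
    """
    if train_end >= val_start:
        msg = (
            f"Training end date {train_end} is not before "
            f"validation start date {val_start}."
        )
        raise ValueError(msg)


# Training ends the day before validation starts, so no error is raised.
check_no_overlap(datetime(2020, 12, 31), datetime(2021, 1, 1))
```

Failing early like this keeps validation samples from leaking into the training period.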

The dataset files contain functions that define how the dataset is split between the workers (worker_init_func) and how it is iterated over to produce the data batches fed as input into the model (__iter__).
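As a rough illustration of splitting a dataset between workers, the sketch below assigns each worker one contiguous chunk of sample indices. The function shard_for_worker and the contiguous-chunk strategy are assumptions for illustration only; the splitting actually performed by worker_init_func and per_worker_init may differ.

```python
from __future__ import annotations


def shard_for_worker(indices: list[int], n_workers: int, worker_id: int) -> list[int]:
    """Return the contiguous chunk of `indices` assigned to one worker.

    Hypothetical helper: assumes an even, contiguous split.
    """
    if n_workers < 1 or not 0 <= worker_id < n_workers:
        raise ValueError("worker_id must satisfy 0 <= worker_id < n_workers")
    chunk = -(-len(indices) // n_workers)  # ceiling division
    low = worker_id * chunk
    high = min(low + chunk, len(indices))
    return indices[low:high]


# Ten samples shared between three workers.
all_indices = list(range(10))
shards = [shard_for_worker(all_indices, 3, w) for w in range(3)]
```

With ten indices and three workers, the chunk size is four, so the last worker receives the two remaining indices. Each index lands in exactly one shard, so no sample is read twice within an epoch.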

Note

Users wishing to change the format of the batch input into the model should subclass NativeGridDataset and override the __iter__ method.
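The subclassing pattern described in the note might look like the sketch below. To keep the example runnable without anemoi-training installed, StandInGridDataset is a hypothetical stand-in for NativeGridDataset, and PairedGridDataset is an invented subclass name. In real code the base class would be NativeGridDataset, and the yielded objects would be whatever its __iter__ produces.

```python
from __future__ import annotations

from collections.abc import Iterator


class StandInGridDataset:
    """Hypothetical stand-in for NativeGridDataset (assumption, not the real class)."""

    def __init__(self, samples: list[list[float]]) -> None:
        self.samples = samples

    def __iter__(self) -> Iterator[list[float]]:
        # The real class yields data batches; here we yield plain lists.
        yield from self.samples


class PairedGridDataset(StandInGridDataset):
    """Change the batch format by overriding __iter__ only."""

    def __iter__(self) -> Iterator[tuple[int, list[float]]]:
        # Reuse the parent's iteration and wrap each sample with its position.
        for position, sample in enumerate(super().__iter__()):
            yield position, sample


batches = list(PairedGridDataset([[1.0, 2.0], [3.0, 4.0]]))
```

Overriding only __iter__ leaves the parent's initialisation and worker handling untouched, which is why the note recommends this route.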

The singledataset.py file contains the NativeGridDataset class, which is used for deterministic model training.

class anemoi.training.data.dataset.singledataset.NativeGridDataset(data_reader: Callable, grid_indices: type[BaseGridIndices], relative_date_indices: list, shuffle: bool = True, label: str = 'generic')

Bases: IterableDataset

Iterable dataset for Anemoi data on arbitrary grids.

property statistics: dict

Return dataset statistics.

property metadata: dict

Return dataset metadata.

property supporting_arrays: dict

Return dataset supporting_arrays.

property name_to_index: dict

Return the mapping from variable name to index.

property resolution: dict

Return dataset resolution.

property valid_date_indices: ndarray

Return valid date indices.

A date t is valid if we can sample the elements t + i for every relative_date_index i.
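The validity rule can be sketched as follows. The function name compute_valid_date_indices, the explicit n_dates argument, and the optional set of missing date indices are assumptions made for illustration; they are not taken from the library.

```python
from __future__ import annotations


def compute_valid_date_indices(
    n_dates: int,
    relative_date_indices: list[int],
    missing: set[int] | None = None,
) -> list[int]:
    """Return every t for which t + i is available for all relative indices i.

    Hypothetical helper: a date is treated as unavailable if it lies outside
    [0, n_dates) or is listed in `missing` (both are assumptions).
    """
    missing = missing or set()
    valid = []
    for t in range(n_dates):
        needed = [t + i for i in relative_date_indices]
        if all(0 <= d < n_dates and d not in missing for d in needed):
            valid.append(t)
    return valid


# Six dates, each sample needs t, t + 1 and t + 2.
full = compute_valid_date_indices(6, [0, 1, 2])
# Same setup, but date index 3 is missing from the dataset.
gappy = compute_valid_date_indices(6, [0, 1, 2], missing={3})
```

In the first call the last two dates are dropped because t + 2 would run past the end of the dataset. In the second call only t = 0 survives, since every other window touches the missing date.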

set_comm_group_info(global_rank: int, model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, reader_group_rank: int, reader_group_size: int) → None

Set model and reader communication group information (called by DDPGroupStrategy).

Parameters:
  • global_rank (int) – Global rank

  • model_comm_group_id (int) – Model communication group ID

  • model_comm_group_rank (int) – Model communication group rank

  • model_comm_num_groups (int) – Number of model communication groups

  • reader_group_rank (int) – Reader group rank

  • reader_group_size (int) – Reader group size

per_worker_init(n_workers: int, worker_id: int) → None

Called by worker_init_func on each copy of the dataset.

This initialises the copy after the worker process has been spawned.

Parameters:
  • n_workers (int) – Number of workers

  • worker_id (int) – Worker ID

The ensdataset.py file contains the EnsNativeGridDataset class, which is used for ensemble model training.

class anemoi.training.data.dataset.ensdataset.EnsNativeGridDataset(data_reader: Callable, grid_indices: type[BaseGridIndices], relative_date_indices: list, shuffle: bool = True, label: str = 'generic', ens_members_per_device: int = 1, num_gpus_per_ens: int = 1, num_gpus_per_model: int = 1)

Bases: NativeGridDataset

Iterable ensemble dataset for Anemoi data on arbitrary grids.

property num_eda_members: int

Return number of EDA members.

property eda_flag: bool

Return whether EDA is enabled.

sample_eda_members(num_eda_members: int = 9) → ndarray

Subselect EDA ensemble members assigned to the current device.

set_comm_group_info(global_rank: int, model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, ens_comm_group_id: int, ens_comm_group_rank: int, ens_comm_num_groups: int, reader_group_rank: int, reader_group_size: int) → None

Set model, ensemble and reader communication group information (called by DDPGroupStrategy).

Parameters:
  • global_rank (int) – Global rank

  • model_comm_group_id (int) – Model communication group ID

  • model_comm_group_rank (int) – Model communication group rank

  • model_comm_num_groups (int) – Number of model communication groups

  • ens_comm_group_id (int) – Ensemble communication group ID

  • ens_comm_group_rank (int) – Ensemble communication group rank

  • ens_comm_num_groups (int) – Number of ensemble communication groups

  • reader_group_rank (int) – Reader group rank

  • reader_group_size (int) – Reader group size

per_worker_init(n_workers: int, worker_id: int) → None

Called by worker_init_func on each copy of the dataset.

This initialises the copy after the worker process has been spawned.

Parameters:
  • n_workers (int) – Number of workers

  • worker_id (int) – Worker ID