Data
This module initialises the dataset (constructed using anemoi-datasets) and loads the data into the model. It also performs a series of checks, for example that the training dataset's end date is before the validation dataset's start date.
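As an illustration of the kind of check described above, here is a minimal sketch of a date-ordering test. The function name `check_no_overlap` and its arguments are hypothetical and are not part of anemoi-training; the actual checks may be implemented differently.

```python
from datetime import datetime


def check_no_overlap(train_end: datetime, val_start: datetime) -> None:
    """Raise if the training period does not end before validation starts."""
    # Hypothetical helper: equal dates are treated as an overlap.
    if train_end >= val_start:
        raise ValueError(
            f"Training end date {train_end:%Y-%m-%d} must be before "
            f"validation start date {val_start:%Y-%m-%d}."
        )


# A training period ending the day before validation starts is accepted.
check_no_overlap(datetime(2020, 12, 31), datetime(2021, 1, 1))
```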
dataset.py contains functions which define how the dataset gets
split between the workers (worker_init_func) and how the dataset is
iterated across to produce the data batches that get fed as input into
the model (__iter__).
Users wishing to change the format of the batch input into the model
should sub-class NativeGridDataset and change the __iter__
function.
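The sub-classing pattern can be sketched as follows. To keep the example self-contained it uses a hypothetical stand-in class, `StandInGridDataset`, instead of the real NativeGridDataset (whose constructor needs a data reader and grid indices); the dictionary batch format is likewise only an assumption chosen for illustration.

```python
from collections.abc import Iterator


class StandInGridDataset:
    """Hypothetical stand-in for NativeGridDataset: yields one sample at a time."""

    def __init__(self, samples: list[list[float]]) -> None:
        self.samples = samples

    def __iter__(self) -> Iterator[list[float]]:
        yield from self.samples


class DictBatchDataset(StandInGridDataset):
    """Sub-class that changes the format of what is fed to the model."""

    def __iter__(self) -> Iterator[dict[str, object]]:
        # Reuse the parent's iteration and wrap each sample in a dictionary.
        for index, sample in enumerate(super().__iter__()):
            yield {"index": index, "x": sample}


batches = list(DictBatchDataset([[1.0, 2.0], [3.0]]))
```

With the real class, the same idea would apply: keep the parent's constructor and override only `__iter__`.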
- class anemoi.training.data.dataset.NativeGridDataset(data_reader: Callable, grid_indices: type[BaseGridIndices], rollout: int = 1, multistep: int = 1, timeincrement: int = 1, shuffle: bool = True, label: str = 'generic', effective_bs: int = 1)
Bases: IterableDataset
Iterable dataset for Anemoi data on arbitrary grids.
- property valid_date_indices: ndarray
Return valid date indices.
A date t is valid if we can sample the sequence
(t - multistep + 1, …, t + rollout)
without missing data (if timeincrement is 1).
If there are no missing dates, the total number of valid initial conditions (ICs) is the dataset length minus rollout minus the additional multistep inputs, i.e. multistep - 1 (if timeincrement is 1).
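The validity rule above can be sketched in plain Python. The helper `valid_indices` is hypothetical and only covers the timeincrement = 1 case described in the text; it is not the library's implementation.

```python
def valid_indices(
    n_dates: int, missing: set[int], rollout: int = 1, multistep: int = 1
) -> list[int]:
    """Return every index t whose window (t - multistep + 1, ..., t + rollout) is complete.

    Assumes timeincrement == 1.
    """
    valid = []
    # t must leave room for multistep - 1 earlier inputs and rollout later targets.
    for t in range(multistep - 1, n_dates - rollout):
        window = range(t - multistep + 1, t + rollout + 1)
        if missing.isdisjoint(window):
            valid.append(t)
    return valid


# No missing dates: 10 - 2 - (3 - 1) = 6 valid indices.
full = valid_indices(10, set(), rollout=2, multistep=3)   # [2, 3, 4, 5, 6, 7]

# Date 5 missing: every window that touches index 5 is discarded.
gappy = valid_indices(10, {5}, rollout=2, multistep=3)    # [2]
```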
- set_comm_group_info(global_rank: int, model_comm_group_id: int, model_comm_group_rank: int, model_comm_num_groups: int, reader_group_rank: int, reader_group_size: int) → None
Set model and reader communication group information (called by DDPGroupStrategy).
- Parameters:
global_rank (int) – Global rank
model_comm_group_id (int) – Model communication group ID
model_comm_group_rank (int) – Model communication group rank
model_comm_num_groups (int) – Number of model communication groups
reader_group_rank (int) – Reader group rank
reader_group_size (int) – Reader group size
- anemoi.training.data.dataset.worker_init_func(worker_id: int) → None
Configures each dataset worker process.
Calls WeatherBenchDataset.per_worker_init() on each dataset object.
- Parameters:
worker_id (int) – Worker ID
- Raises:
RuntimeError – If worker_info is None
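How the sample indices are divided between workers is not spelled out here. One common scheme, shown purely as an assumption, gives each worker a contiguous and near-equal slice. The function `worker_slice` is hypothetical and does not reproduce the library's per-worker initialisation.

```python
def worker_slice(n_samples: int, n_workers: int, worker_id: int) -> range:
    """Return the contiguous block of sample indices assigned to one worker."""
    if not 0 <= worker_id < n_workers:
        raise ValueError(f"worker_id {worker_id} is outside [0, {n_workers})")
    # The first `extra` workers each receive one additional sample.
    base, extra = divmod(n_samples, n_workers)
    start = worker_id * base + min(worker_id, extra)
    stop = start + base + (1 if worker_id < extra else 0)
    return range(start, stop)


# Ten samples over three workers: slices of size 4, 3 and 3.
slices = [list(worker_slice(10, 3, w)) for w in range(3)]
```

Every sample is assigned to exactly one worker, so no data is duplicated or dropped across workers under this scheme.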