Create checkpoint migrations
Note
For more technical details about migrations, see Migration System.
Migrate
To create a new migration, run:
anemoi-models migration create MIGRATION_NAME
This will create a new migration script at the provided location that looks like:
from anemoi.models.migrations import CkptType
from anemoi.models.migrations import MigrationMetadata

metadata = MigrationMetadata(
    versions={
        "migration": "1.0.0",
        "anemoi-models": "0.8.1",
    }
)


def migrate(ckpt: CkptType) -> CkptType:
    """
    Migrate the checkpoint.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict.

    Returns
    -------
    CkptType
        The migrated checkpoint dict.
    """
    return ckpt
migrate receives an old checkpoint (made before your changes), and
must return a checkpoint compatible with your changes.
Note
The metadata object is automatically generated. You should not change this part of the script.
In particular, it contains the version of the migration system. This is to allow future changes in the API but still support older migration scripts.
Migrations are only done for training checkpoints. Users are expected to re-generate the inference checkpoint once the training checkpoint is migrated.
If your migration only applies to a specific architecture, add a guard in the migration script. For example, to restrict it to a specific processor class:
def migrate(ckpt: CkptType) -> CkptType:
    if ckpt["hyper_parameters"]["config"].model.processor._target_ == "anemoi.models.layers.processor.TransformerProcessor":
        # Do stuff
        ...
    return ckpt
Migration names start with a timestamp that determines their order of execution. The timestamp is set when the migration script is created. However, a new commit in main may contain a migration script with a later timestamp than one or several of your migration scripts, which would break the correct order.
The unit test test_migration_order will check whether the correct
order is preserved. If you get an error, you can run anemoi-models
migration fix-order to update the timestamps of your scripts.
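To make the ordering problem concrete, here is a minimal sketch of such a check. It assumes, purely for illustration, that script names begin with an integer timestamp followed by an underscore. The file names and the `out_of_order` helper are hypothetical, and this is not how test_migration_order is actually implemented.

```python
def timestamp(name: str) -> int:
    """Return the leading integer timestamp of a script name (assumed format)."""
    return int(name.split("_", 1)[0])


def out_of_order(on_main: list[str], local: list[str]) -> list[str]:
    """Return local scripts whose timestamp is not later than every script on main."""
    latest_on_main = max((timestamp(n) for n in on_main), default=0)
    return [n for n in local if timestamp(n) <= latest_on_main]


# Hypothetical names: your script was created before a newer one landed on main.
on_main = ["1750840000_rename_encoder.py", "1752000000_move_schemas.py"]
local = ["1751000000_my_change.py"]

stale = out_of_order(on_main, local)
print(stale)
```

Any script reported here would run before a migration it should follow, which is the situation that fix-order resolves by updating the timestamps.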
Simple example
For example, if you renamed a layer x to y, you can make the following migration:
from anemoi.models.migrations import CkptType
from anemoi.models.migrations import MigrationMetadata

metadata = MigrationMetadata(
    versions={
        "migration": "1.0.0",
        "anemoi-models": "0.8.1",
    }
)


def migrate(ckpt: CkptType) -> CkptType:
    """
    Migrate the checkpoint.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict.

    Returns
    -------
    CkptType
        The migrated checkpoint dict.
    """
    ckpt["state_dict"]["y"] = ckpt["state_dict"].pop("x")
    return ckpt
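In practice a renamed layer often owns several state_dict entries (for example a weight and a bias), so every key sharing the old prefix has to move. A minimal sketch, assuming the state dict is a plain mapping keyed by dotted parameter names; the `rename_prefix` helper and the key names are hypothetical:

```python
def rename_prefix(state_dict: dict, old: str, new: str) -> dict:
    """Rename every key equal to `old` or starting with `old + "."`, in place."""
    for key in list(state_dict):  # iterate over a copy: the dict is mutated below
        if key == old or key.startswith(old + "."):
            state_dict[new + key[len(old):]] = state_dict.pop(key)
    return state_dict


# Stand-in checkpoint; real values would be tensors.
ckpt = {"state_dict": {"x.weight": 1, "x.bias": 2, "xy.weight": 3}}
rename_prefix(ckpt["state_dict"], "x", "y")
print(sorted(ckpt["state_dict"]))
```

Matching on `old + "."` rather than on the bare prefix keeps unrelated layers such as `xy` untouched.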
Setup callback
Python objects are stored by reference in a pickle object. This means that if you move (or remove) a class, old checkpoints cannot be loaded.
Note
Migration scripts use a special Unpickler that masks these import errors so that the migration information in the checkpoint can still be read.
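The following sketch shows the underlying problem with the standard library only. It is not the Unpickler used by anemoi-models: the module names `demo_old_location` and `demo_new_location`, the `Normalizer` class and the `RemappingUnpickler` are all hypothetical, and the remapping is just one generic way to redirect a moved class.

```python
import io
import pickle
import sys
import types


class Normalizer:
    def __init__(self, scale: float) -> None:
        self.scale = scale


def publish(module_name: str) -> None:
    """Expose Normalizer as `module_name.Normalizer` in a throwaway module."""
    module = types.ModuleType(module_name)
    module.Normalizer = Normalizer
    Normalizer.__module__ = module_name
    Normalizer.__qualname__ = "Normalizer"
    sys.modules[module_name] = module


# 1. Pickle an instance while the class lives in the "old" module.
publish("demo_old_location")
payload = pickle.dumps(Normalizer(2.0))
# The payload stores the import path of the class, not the class body.
print(b"demo_old_location" in payload)

# 2. Simulate moving the class: the old module disappears, a new one appears.
del sys.modules["demo_old_location"]
publish("demo_new_location")

try:
    pickle.loads(payload)
except ModuleNotFoundError as error:
    print("plain load failed:", error)


# 3. An unpickler that redirects the old import path to the new one.
class RemappingUnpickler(pickle.Unpickler):
    moved = {("demo_old_location", "Normalizer"): ("demo_new_location", "Normalizer")}

    def find_class(self, module: str, name: str):
        module, name = self.moved.get((module, name), (module, name))
        return super().find_class(module, name)


restored = RemappingUnpickler(io.BytesIO(payload)).load()
print(restored.scale)
```

The plain load fails because the stored path no longer resolves, while the remapping unpickler succeeds once it is told where the class now lives.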
The setup callbacks are functions that fix import errors. They are run
before loading the checkpoint. To add a setup callback to your script,
define the migrate_setup callback:
from anemoi.models.migrations import MigrationContext


def migrate_setup(context: MigrationContext) -> None:
    """
    Migrate setup callback to be run before loading the checkpoint.

    Parameters
    ----------
    context : MigrationContext
        A MigrationContext instance
    """
To generate your script with the setup callbacks, use the
--with-setup argument:
anemoi-models migration create MIGRATION_NAME --with-setup
The context object provides three methods to fix import errors:
- context.move_attribute(start_path, end_path) to indicate that an attribute was moved from start_path to end_path.
- context.move_module(start_path, end_path) to indicate that a module was moved from start_path to end_path.
- context.delete_attribute(path) to indicate that an attribute was removed. You can use the wildcard "*" to delete any attribute in the module.
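As a rough sketch of how the three calls might be combined in one setup callback, the snippet below uses a stand-in `RecordingContext` so that it runs without anemoi-models installed. The stand-in class and every `pkg.*` path are hypothetical, and writing the wildcard as a trailing `.*` on the module path is an assumption about the expected form.

```python
class RecordingContext:
    """Stand-in for MigrationContext that only records the calls it receives."""

    def __init__(self) -> None:
        self.calls: list[tuple] = []

    def move_attribute(self, start_path: str, end_path: str) -> None:
        self.calls.append(("move_attribute", start_path, end_path))

    def move_module(self, start_path: str, end_path: str) -> None:
        self.calls.append(("move_module", start_path, end_path))

    def delete_attribute(self, path: str) -> None:
        self.calls.append(("delete_attribute", path))


def migrate_setup(context) -> None:
    # A whole module was renamed (hypothetical paths).
    context.move_module("pkg.old_module", "pkg.new_module")
    # One class moved and was renamed along the way.
    context.move_attribute("pkg.a.OldName", "pkg.b.NewName")
    # A single attribute was removed.
    context.delete_attribute("pkg.c.Removed")
    # Assumed wildcard form: every attribute of this module was removed.
    context.delete_attribute("pkg.legacy.*")


context = RecordingContext()
migrate_setup(context)
for call in context.calls:
    print(call)
```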
For example, if you renamed the module
anemoi.models.schemas.data_processor to
anemoi.models.schemas.data, your migration might look like:
from anemoi.models.migrations import CkptType
from anemoi.models.migrations import MigrationContext
from anemoi.models.migrations import MigrationMetadata

metadata = MigrationMetadata(
    versions={
        "migration": "1.0.0",
        "anemoi-models": "0.8.1",
    }
)


def migrate_setup(context: MigrationContext) -> None:
    """
    Migrate setup callback to be run before loading the checkpoint.

    Parameters
    ----------
    context : MigrationContext
        A MigrationContext instance
    """
    context.move_module("anemoi.models.schemas.data_processor", "anemoi.models.schemas.data")


def migrate(ckpt: CkptType) -> CkptType:
    """
    Migrate the checkpoint.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict.

    Returns
    -------
    CkptType
        The migrated checkpoint dict.
    """
    # This is also executed. You can update the checkpoint if you need to.
    return ckpt
Similarly, if you moved the class NormalizerSchema from
anemoi.training.schemas.data to
anemoi.models.schemas.data_processor, the setup callback might look
like:
def migrate_setup(context: MigrationContext) -> None:
    """
    Migrate setup callback to be run before loading the checkpoint.

    Parameters
    ----------
    context : MigrationContext
        A MigrationContext instance
    """
    context.move_attribute(
        "anemoi.training.schemas.data.NormalizerSchema", "anemoi.models.schemas.data_processor.NormalizerSchema"
    )
Note
The attribute can also have a different name in the final location.
Final migrations
If the modifications are too complex and it is decided that migrating old checkpoints should not be supported, you can create a "final" migration with:
anemoi-models migration create --final MIGRATION_NAME
Full example
Here is a full example of a migration to fix PR 433:
from anemoi.models.migrations import CkptType
from anemoi.models.migrations import MigrationContext
from anemoi.models.migrations import MigrationMetadata

metadata = MigrationMetadata(
    versions={
        "migration": "1.0.0",
        "anemoi-models": "0.9.0",
    }
)


def migrate_setup(context: MigrationContext) -> None:
    """
    Migrate setup callback to be run before loading the checkpoint.

    Parameters
    ----------
    context : MigrationContext
        A MigrationContext instance
    """
    context.move_attribute(
        "anemoi.training.schemas.data.NormalizerSchema", "anemoi.models.schemas.data_processor.NormalizerSchema"
    )


def migrate(ckpt: CkptType) -> CkptType:
    """
    Migrate the checkpoint.

    Parameters
    ----------
    ckpt : CkptType
        The checkpoint dict.

    Returns
    -------
    CkptType
        The migrated checkpoint dict.
    """
    return ckpt
Best practices
Here are best practices that will help you create good migration scripts.
Use an if guard to apply a script only to a specific architecture.
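One possible way to write such a guard so that checkpoints lacking the expected configuration are skipped rather than raising is sketched below. It uses `SimpleNamespace` as a stand-in for the real config object, whose attribute access may behave differently. The `uses_processor` and `fake_ckpt` helpers and the `migrated` key are hypothetical.

```python
from types import SimpleNamespace

TRANSFORMER = "anemoi.models.layers.processor.TransformerProcessor"


def uses_processor(ckpt: dict, target: str) -> bool:
    """Return True when the checkpoint config names `target` as its processor."""
    config = ckpt.get("hyper_parameters", {}).get("config")
    processor = getattr(getattr(config, "model", None), "processor", None)
    return getattr(processor, "_target_", None) == target


def migrate(ckpt: dict) -> dict:
    if uses_processor(ckpt, TRANSFORMER):
        ckpt["migrated"] = True  # placeholder for the real change
    return ckpt


def fake_ckpt(target: str) -> dict:
    """Build a stand-in checkpoint whose config points at `target`."""
    processor = SimpleNamespace(_target_=target)
    config = SimpleNamespace(model=SimpleNamespace(processor=processor))
    return {"hyper_parameters": {"config": config}}


print("migrated" in migrate(fake_ckpt(TRANSFORMER)))
print("migrated" in migrate(fake_ckpt("some.other.Processor")))
```

Checkpoints for other architectures pass through unchanged, so the same script can safely run over every checkpoint.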