# Source code for mtlearn.layers.checkpoint

"""Checkpoint helpers for models that contain CFP preprocessing layers."""

from __future__ import annotations

import inspect
from typing import Any, Callable

import torch

from .ConnectedFilterPreprocessingLayer import ConnectedFilterPreprocessingLayer


def collect_cfp_configs(model: torch.nn.Module) -> dict[str, dict[str, Any]]:
    """Gather the config of each CFP preprocessing layer inside ``model``.

    Walks ``model.named_modules()`` and, for every module that is a
    ``ConnectedFilterPreprocessingLayer``, stores the result of its
    ``get_config()`` under the module's name.

    Args:
        model: Module whose sub-modules are scanned.

    Returns:
        Mapping from module name to that layer's ``get_config()`` result.
        The mapping is empty when no CFP layer is found.

    Raises:
        TypeError: If ``model`` is not a ``torch.nn.Module``.
    """
    if not isinstance(model, torch.nn.Module):
        raise TypeError("model must be a torch.nn.Module.")

    configs: dict[str, dict[str, Any]] = {}
    for qualified_name, submodule in model.named_modules():
        if isinstance(submodule, ConnectedFilterPreprocessingLayer):
            configs[qualified_name] = submodule.get_config()
    return configs
def save_checkpoint(
    path,
    model: torch.nn.Module,
) -> dict[str, Any]:
    """Write a checkpoint holding the model weights and its CFP configs.

    The saved payload is a dict with exactly two entries:

    * ``"model_state_dict"`` -- ``model.state_dict()``.
    * ``"cfp_configs"`` -- the mapping returned by
      ``collect_cfp_configs(model)``, so that a caller can rebuild the
      model before calling ``load_state_dict``.

    Args:
        path: Destination handed unchanged to ``torch.save``.
        model: Module whose weights and CFP configs are stored.

    Returns:
        The payload dict that was passed to ``torch.save``.
    """
    # Same evaluation order as the dict literal: weights first, configs second.
    weights = model.state_dict()
    configs = collect_cfp_configs(model)

    payload = {
        "model_state_dict": weights,
        "cfp_configs": configs,
    }
    torch.save(payload, path)
    return payload
def _load_torch_checkpoint(path, *, device=None, weights_only: bool = True):
    """Call ``torch.load`` on ``path``, forwarding ``weights_only`` when supported.

    Support for the ``weights_only`` keyword is detected from the signature
    of ``torch.load`` instead of by catching ``TypeError`` around the call.
    A ``TypeError`` raised while loading therefore propagates to the caller
    rather than triggering a second load without ``weights_only`` (which
    would drop the caller's request and re-read ``path``, possibly from a
    partly consumed file-like object).

    Args:
        path: File name or file-like object handed to ``torch.load``.
        device: Forwarded as ``map_location``.
        weights_only: Forwarded to ``torch.load`` when the keyword exists.

    Returns:
        Whatever ``torch.load`` returns.
    """
    try:
        load_parameters = inspect.signature(torch.load).parameters
    except (TypeError, ValueError):
        # Signature not introspectable: keep honouring the caller's request.
        supports_weights_only = True
    else:
        supports_weights_only = "weights_only" in load_parameters

    if supports_weights_only:
        return torch.load(path, map_location=device, weights_only=weights_only)

    # torch.load without a ``weights_only`` keyword: fall back to a plain load.
    return torch.load(path, map_location=device)


def _build_model_from_factory(
    factory: Callable[..., torch.nn.Module],
    cfp_configs: dict[str, dict[str, Any]],
):
    """Call ``factory`` with or without ``cfp_configs`` depending on its signature.

    Required parameters are those without a default that are positional-only,
    positional-or-keyword or keyword-only (``*args`` / ``**kwargs`` are not
    counted).

    * No required parameter: ``factory()`` is called.
    * Exactly one required parameter that is not keyword-only:
      ``factory(cfp_configs)`` is called.
    * Signature cannot be inspected: ``factory(cfp_configs)`` is called.

    Args:
        factory: Callable expected to return a ``torch.nn.Module``.
        cfp_configs: Saved CFP configs passed to factories that take them.

    Returns:
        The module produced by ``factory``.

    Raises:
        TypeError: If the factory has any other required-parameter layout,
            or if it does not return a ``torch.nn.Module``.
    """
    countable_kinds = (
        inspect.Parameter.POSITIONAL_ONLY,
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )

    try:
        signature = inspect.signature(factory)
    except (TypeError, ValueError):
        # Some callables expose no signature; assume they take the configs.
        model = factory(cfp_configs)
    else:
        required = [
            parameter
            for parameter in signature.parameters.values()
            if parameter.default is inspect.Parameter.empty
            and parameter.kind in countable_kinds
        ]
        if not required:
            model = factory()
        elif (
            len(required) == 1
            and required[0].kind is not inspect.Parameter.KEYWORD_ONLY
        ):
            model = factory(cfp_configs)
        else:
            raise TypeError(
                "model_factory must accept zero arguments or one positional cfp_configs argument."
            )

    if not isinstance(model, torch.nn.Module):
        raise TypeError("model_factory must return a torch.nn.Module.")
    return model
def load_checkpoint(
    path,
    model_or_factory: torch.nn.Module | Callable[..., torch.nn.Module],
    *,
    device=None,
    strict: bool = True,
    weights_only: bool = True,
) -> tuple[torch.nn.Module, dict[str, Any]]:
    """Restore a model from a checkpoint written by ``save_checkpoint``.

    Args:
        path: Checkpoint location, forwarded to the torch loader.
        model_or_factory: Either a ready-made module, which is used as is,
            or a callable returning one. A callable may take no arguments
            (the model already hard-codes its CFP layers) or a single
            positional ``cfp_configs`` argument that receives the configs
            stored in the checkpoint.
        device: When not ``None``, used as the load ``map_location`` and
            the model is moved to it before the weights are loaded.
        strict: Forwarded to ``load_state_dict``.
        weights_only: Forwarded to the torch loader.

    Returns:
        A ``(model, checkpoint)`` pair, where ``checkpoint`` is the full
        payload read from ``path``.

    Raises:
        TypeError: If ``model_or_factory`` is neither a module nor callable.
    """
    checkpoint = _load_torch_checkpoint(path, device=device, weights_only=weights_only)
    saved_configs = checkpoint.get("cfp_configs", {})

    # Modules are callable too, so the module check has to come first.
    if not isinstance(model_or_factory, torch.nn.Module):
        if not callable(model_or_factory):
            raise TypeError("model_or_factory must be a torch.nn.Module or callable.")
        model = _build_model_from_factory(model_or_factory, saved_configs)
    else:
        model = model_or_factory

    if device is not None:
        model.to(device)

    model.load_state_dict(checkpoint["model_state_dict"], strict=strict)
    return model, checkpoint
# Public names exported by ``from ... import *``.
__all__ = [
    "collect_cfp_configs",
    "load_checkpoint",
    "save_checkpoint",
]